mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-30 19:53:44 +00:00
Folder restructure for evals/datasets/scoring (#419)
* rename evals related stuff * fix datasetio * fix scoring test * localfs -> LocalFS * refactor scoring * refactor scoring * remove 8b_correctness scoring_fn from tests * tests w/ eval params * scoring fn braintrust fixture * import
This commit is contained in:
parent
2b7d70ba86
commit
b4416b72fd
37 changed files with 141 additions and 100 deletions
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import MetaReferenceDatasetIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceDatasetIOConfig,
|
||||
_deps,
|
||||
):
|
||||
from .datasetio import MetaReferenceDatasetIOImpl
|
||||
|
||||
impl = MetaReferenceDatasetIOImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,9 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.datasetio import * # noqa: F401, F403
|
||||
|
||||
|
||||
class MetaReferenceDatasetIOConfig(BaseModel): ...
|
|
@ -1,133 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import List, Optional
|
||||
|
||||
import pandas
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
|
||||
from .config import MetaReferenceDatasetIOConfig
|
||||
|
||||
|
||||
class BaseDataset(ABC):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, idx):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def load(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetInfo:
|
||||
dataset_def: DatasetDef
|
||||
dataset_impl: BaseDataset
|
||||
|
||||
|
||||
class PandasDataframeDataset(BaseDataset):
|
||||
def __init__(self, dataset_def: DatasetDef, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_def = dataset_def
|
||||
self.df = None
|
||||
|
||||
def __len__(self) -> int:
|
||||
assert self.df is not None, "Dataset not loaded. Please call .load() first"
|
||||
return len(self.df)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
assert self.df is not None, "Dataset not loaded. Please call .load() first"
|
||||
if isinstance(idx, slice):
|
||||
return self.df.iloc[idx].to_dict(orient="records")
|
||||
else:
|
||||
return self.df.iloc[idx].to_dict()
|
||||
|
||||
def _validate_dataset_schema(self, df) -> pandas.DataFrame:
|
||||
# note that we will drop any columns in dataset that are not in the schema
|
||||
df = df[self.dataset_def.dataset_schema.keys()]
|
||||
# check all columns in dataset schema are present
|
||||
assert len(df.columns) == len(self.dataset_def.dataset_schema)
|
||||
# TODO: type checking against column types in dataset schema
|
||||
return df
|
||||
|
||||
def load(self) -> None:
|
||||
if self.df is not None:
|
||||
return
|
||||
|
||||
df = get_dataframe_from_url(self.dataset_def.url)
|
||||
if df is None:
|
||||
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
|
||||
|
||||
self.df = self._validate_dataset_schema(df)
|
||||
|
||||
|
||||
class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||
def __init__(self, config: MetaReferenceDatasetIOConfig) -> None:
|
||||
self.config = config
|
||||
# local registry for keeping track of datasets within the provider
|
||||
self.dataset_infos = {}
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def register_dataset(
|
||||
self,
|
||||
dataset_def: DatasetDef,
|
||||
) -> None:
|
||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||
self.dataset_infos[dataset_def.identifier] = DatasetInfo(
|
||||
dataset_def=dataset_def,
|
||||
dataset_impl=dataset_impl,
|
||||
)
|
||||
|
||||
async def list_datasets(self) -> List[DatasetDef]:
|
||||
return [i.dataset_def for i in self.dataset_infos.values()]
|
||||
|
||||
async def get_rows_paginated(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
dataset_info = self.dataset_infos.get(dataset_id)
|
||||
dataset_info.dataset_impl.load()
|
||||
|
||||
if page_token and not page_token.isnumeric():
|
||||
raise ValueError("Invalid page_token")
|
||||
|
||||
if page_token is None or len(page_token) == 0:
|
||||
next_page_token = 0
|
||||
else:
|
||||
next_page_token = int(page_token)
|
||||
|
||||
start = next_page_token
|
||||
if rows_in_page == -1:
|
||||
end = len(dataset_info.dataset_impl)
|
||||
else:
|
||||
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
|
||||
|
||||
rows = dataset_info.dataset_impl[start:end]
|
||||
|
||||
return PaginatedRowsResult(
|
||||
rows=rows,
|
||||
total_count=len(rows),
|
||||
next_page_token=str(end),
|
||||
)
|
|
@ -1,27 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceEvalConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
):
|
||||
from .eval import MetaReferenceEvalImpl
|
||||
|
||||
impl = MetaReferenceEvalImpl(
|
||||
config,
|
||||
deps[Api.datasetio],
|
||||
deps[Api.datasets],
|
||||
deps[Api.scoring],
|
||||
deps[Api.inference],
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,9 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.eval import * # noqa: F401, F403
|
||||
|
||||
|
||||
class MetaReferenceEvalConfig(BaseModel): ...
|
|
@ -1,205 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from enum import Enum
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from .....apis.common.job_types import Job
|
||||
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from tqdm import tqdm
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval_tasks import EvalTaskDef
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
|
||||
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
|
||||
class ColumnName(Enum):
|
||||
input_query = "input_query"
|
||||
expected_answer = "expected_answer"
|
||||
chat_completion_input = "chat_completion_input"
|
||||
completion_input = "completion_input"
|
||||
generated_answer = "generated_answer"
|
||||
|
||||
|
||||
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceEvalConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
scoring_api: Scoring,
|
||||
inference_api: Inference,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
self.scoring_api = scoring_api
|
||||
self.inference_api = inference_api
|
||||
|
||||
# TODO: assume sync job, will need jobs API for async scheduling
|
||||
self.jobs = {}
|
||||
|
||||
self.eval_tasks = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def register_eval_task(self, task_def: EvalTaskDef) -> None:
|
||||
self.eval_tasks[task_def.identifier] = task_def
|
||||
|
||||
async def list_eval_tasks(self) -> List[EvalTaskDef]:
|
||||
return list(self.eval_tasks.values())
|
||||
|
||||
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||
|
||||
expected_schemas = [
|
||||
{
|
||||
ColumnName.input_query.value: StringType(),
|
||||
ColumnName.expected_answer.value: StringType(),
|
||||
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
|
||||
},
|
||||
{
|
||||
ColumnName.input_query.value: StringType(),
|
||||
ColumnName.expected_answer.value: StringType(),
|
||||
ColumnName.completion_input.value: CompletionInputType(),
|
||||
},
|
||||
]
|
||||
|
||||
if dataset_def.dataset_schema not in expected_schemas:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
|
||||
)
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
task_id: str,
|
||||
task_config: EvalTaskConfig,
|
||||
) -> Job:
|
||||
task_def = self.eval_tasks[task_id]
|
||||
dataset_id = task_def.dataset_id
|
||||
candidate = task_config.eval_candidate
|
||||
scoring_functions = task_def.scoring_functions
|
||||
|
||||
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=(
|
||||
-1 if task_config.num_examples is None else task_config.num_examples
|
||||
),
|
||||
)
|
||||
res = await self.evaluate_rows(
|
||||
task_id=task_id,
|
||||
input_rows=all_rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
task_config=task_config,
|
||||
)
|
||||
|
||||
# TODO: currently needs to wait for generation before returning
|
||||
# need job scheduler queue (ray/celery) w/ jobs api
|
||||
job_id = str(len(self.jobs))
|
||||
self.jobs[job_id] = res
|
||||
return Job(job_id=job_id)
|
||||
|
||||
async def evaluate_rows(
|
||||
self,
|
||||
task_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
task_config: EvalTaskConfig,
|
||||
) -> EvaluateResponse:
|
||||
candidate = task_config.eval_candidate
|
||||
if candidate.type == "agent":
|
||||
raise NotImplementedError(
|
||||
"Evaluation with generation has not been implemented for agents"
|
||||
)
|
||||
assert (
|
||||
candidate.sampling_params.max_tokens is not None
|
||||
), "SamplingParams.max_tokens must be provided"
|
||||
|
||||
generations = []
|
||||
for x in tqdm(input_rows):
|
||||
if ColumnName.completion_input.value in x:
|
||||
input_content = eval(str(x[ColumnName.completion_input.value]))
|
||||
response = await self.inference_api.completion(
|
||||
model=candidate.model,
|
||||
content=input_content,
|
||||
sampling_params=candidate.sampling_params,
|
||||
)
|
||||
generations.append(
|
||||
{
|
||||
ColumnName.generated_answer.value: response.completion_message.content
|
||||
}
|
||||
)
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_str = str(
|
||||
x[ColumnName.chat_completion_input.value]
|
||||
)
|
||||
input_messages = eval(chat_completion_input_str)
|
||||
input_messages = [UserMessage(**x) for x in input_messages]
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
messages += input_messages
|
||||
response = await self.inference_api.chat_completion(
|
||||
model=candidate.model,
|
||||
messages=messages,
|
||||
sampling_params=candidate.sampling_params,
|
||||
)
|
||||
generations.append(
|
||||
{
|
||||
ColumnName.generated_answer.value: response.completion_message.content
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid input row")
|
||||
|
||||
# scoring with generated_answer
|
||||
score_input_rows = [
|
||||
input_r | generated_r
|
||||
for input_r, generated_r in zip(input_rows, generations)
|
||||
]
|
||||
|
||||
if task_config.type == "app" and task_config.scoring_params is not None:
|
||||
scoring_functions_dict = {
|
||||
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
|
||||
for scoring_fn_id in scoring_functions
|
||||
}
|
||||
else:
|
||||
scoring_functions_dict = {
|
||||
scoring_fn_id: None for scoring_fn_id in scoring_functions
|
||||
}
|
||||
|
||||
score_response = await self.scoring_api.score(
|
||||
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
|
||||
)
|
||||
|
||||
return EvaluateResponse(generations=generations, scores=score_response.results)
|
||||
|
||||
async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]:
|
||||
if job_id in self.jobs:
|
||||
return JobStatus.completed
|
||||
|
||||
return None
|
||||
|
||||
async def job_cancel(self, task_id: str, job_id: str) -> None:
|
||||
raise NotImplementedError("Job cancel is not implemented yet")
|
||||
|
||||
async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse:
|
||||
status = await self.job_status(task_id, job_id)
|
||||
if not status or status != JobStatus.completed:
|
||||
raise ValueError(f"Job is not completed, Status: {status.value}")
|
||||
|
||||
return self.jobs[job_id]
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MetaReferenceScoringConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
):
|
||||
from .scoring import MetaReferenceScoringImpl
|
||||
|
||||
impl = MetaReferenceScoringImpl(
|
||||
config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference]
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,9 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
|
||||
|
||||
class MetaReferenceScoringConfig(BaseModel): ...
|
|
@ -1,135 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import List
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.inference.inference import Inference
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
|
||||
from .config import MetaReferenceScoringConfig
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
|
||||
|
||||
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
|
||||
|
||||
|
||||
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceScoringConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
inference_api: Inference,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
self.inference_api = inference_api
|
||||
self.scoring_fn_id_impls = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
for x in FIXED_FNS:
|
||||
impl = x()
|
||||
for fn_defs in impl.get_supported_scoring_fn_defs():
|
||||
self.scoring_fn_id_impls[fn_defs.identifier] = impl
|
||||
for x in LLM_JUDGE_FNS:
|
||||
impl = x(inference_api=self.inference_api)
|
||||
for fn_defs in impl.get_supported_scoring_fn_defs():
|
||||
self.scoring_fn_id_impls[fn_defs.identifier] = impl
|
||||
self.llm_as_judge_fn = impl
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def list_scoring_functions(self) -> List[ScoringFnDef]:
|
||||
scoring_fn_defs_list = [
|
||||
fn_def
|
||||
for impl in self.scoring_fn_id_impls.values()
|
||||
for fn_def in impl.get_supported_scoring_fn_defs()
|
||||
]
|
||||
|
||||
for f in scoring_fn_defs_list:
|
||||
assert f.identifier.startswith(
|
||||
"meta-reference"
|
||||
), "All meta-reference scoring fn must have identifier prefixed with 'meta-reference'! "
|
||||
|
||||
return scoring_fn_defs_list
|
||||
|
||||
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
|
||||
raise NotImplementedError("Register scoring function not implemented yet")
|
||||
|
||||
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
|
||||
)
|
||||
|
||||
for required_column in ["generated_answer", "expected_answer", "input_query"]:
|
||||
if required_column not in dataset_def.dataset_schema:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a '{required_column}' column."
|
||||
)
|
||||
if dataset_def.dataset_schema[required_column].type != "string":
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
|
||||
)
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
)
|
||||
res = await self.score(
|
||||
input_rows=all_rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
if save_results_dataset:
|
||||
# TODO: persist and register dataset on to server for reading
|
||||
# self.datasets_api.register_dataset()
|
||||
raise NotImplementedError("Save results dataset not implemented yet")
|
||||
|
||||
return ScoreBatchResponse(
|
||||
results=res.results,
|
||||
)
|
||||
|
||||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
) -> ScoreResponse:
|
||||
res = {}
|
||||
for scoring_fn_id in scoring_functions.keys():
|
||||
if scoring_fn_id not in self.scoring_fn_id_impls:
|
||||
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
|
||||
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
|
||||
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
|
||||
score_results = await scoring_fn.score(
|
||||
input_rows, scoring_fn_id, scoring_fn_params
|
||||
)
|
||||
agg_results = await scoring_fn.aggregate(score_results)
|
||||
res[scoring_fn_id] = ScoringResult(
|
||||
score_rows=score_results,
|
||||
aggregated_results=agg_results,
|
||||
)
|
||||
|
||||
return ScoreResponse(
|
||||
results=res,
|
||||
)
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,61 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
|
||||
|
||||
class BaseScoringFn(ABC):
|
||||
"""
|
||||
Base interface class for all meta-reference scoring_fns.
|
||||
Each scoring_fn needs to implement the following methods:
|
||||
- score_row(self, row)
|
||||
- aggregate(self, scoring_fn_results)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
|
||||
return [x for x in self.supported_fn_defs_registry.values()]
|
||||
|
||||
def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None:
|
||||
if scoring_fn_def.identifier in self.supported_fn_defs_registry:
|
||||
raise ValueError(
|
||||
f"Scoring function def with identifier {scoring_fn_def.identifier} already exists."
|
||||
)
|
||||
self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def
|
||||
|
||||
@abstractmethod
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def aggregate(
|
||||
self, scoring_results: List[ScoringResultRow]
|
||||
) -> Dict[str, Any]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> List[ScoringResultRow]:
|
||||
return [
|
||||
await self.score_row(input_row, scoring_fn_identifier, scoring_params)
|
||||
for input_row in input_rows
|
||||
]
|
|
@ -1,31 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
|
||||
FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
|
||||
|
||||
|
||||
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||
num_correct = sum(result["score"] for result in scoring_results)
|
||||
avg_score = num_correct / len(scoring_results)
|
||||
|
||||
return {
|
||||
"accuracy": avg_score,
|
||||
"num_correct": num_correct,
|
||||
"num_total": len(scoring_results),
|
||||
}
|
||||
|
||||
|
||||
def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||
return {
|
||||
"average": sum(
|
||||
result["score"] for result in scoring_results if result["score"] is not None
|
||||
)
|
||||
/ len([_ for _ in scoring_results if _["score"] is not None]),
|
||||
}
|
|
@ -1,55 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
BaseScoringFn,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
|
||||
aggregate_accuracy,
|
||||
)
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.equality import (
|
||||
equality,
|
||||
)
|
||||
|
||||
|
||||
class EqualityScoringFn(BaseScoringFn):
|
||||
"""
|
||||
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
equality.identifier: equality,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "equality",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert "expected_answer" in input_row, "Expected answer not found in input row."
|
||||
assert (
|
||||
"generated_answer" in input_row
|
||||
), "Generated answer not found in input row."
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
score = 1.0 if expected_answer == generated_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
||||
async def aggregate(
|
||||
self, scoring_results: List[ScoringResultRow]
|
||||
) -> Dict[str, Any]:
|
||||
return aggregate_accuracy(scoring_results)
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import ScoringFnDef
|
||||
|
||||
|
||||
equality = ScoringFnDef(
|
||||
identifier="meta-reference::equality",
|
||||
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
||||
return_type=NumberType(),
|
||||
)
|
|
@ -1,39 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
|
||||
JUDGE_PROMPT = """
|
||||
You will be given a question, a expected_answer, and a system_answer.
|
||||
Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
|
||||
Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
|
||||
Provide your feedback as follows:
|
||||
Feedback:::
|
||||
Total rating: (your rating, as a int between 0 and 5)
|
||||
Now here are the question, expected_answer, system_answer.
|
||||
Question: {input_query}
|
||||
Expected Answer: {expected_answer}
|
||||
System Answer: {generated_answer}
|
||||
Feedback:::
|
||||
Total rating:
|
||||
"""
|
||||
|
||||
llm_as_judge_8b_correctness = ScoringFnDef(
|
||||
identifier="meta-reference::llm_as_judge_8b_correctness",
|
||||
description="Llm As Judge Scoring Function",
|
||||
return_type=NumberType(),
|
||||
params=LLMAsJudgeScoringFnParams(
|
||||
prompt_template=JUDGE_PROMPT,
|
||||
judge_model="Llama3.1-8B-Instruct",
|
||||
judge_score_regexes=[
|
||||
r"Total rating: (\d+)",
|
||||
r"rating: (\d+)",
|
||||
r"Rating: (\d+)",
|
||||
],
|
||||
),
|
||||
)
|
|
@ -1,69 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
|
||||
MULTILINGUAL_ANSWER_REGEXES = [
|
||||
r"Answer\s*:",
|
||||
r"Answer\s*:", # Korean invisible character
|
||||
r"উত্তর\s*:",
|
||||
r"उत्तर\s*:",
|
||||
r"উত্তরঃ",
|
||||
r"উত্তর\s*:",
|
||||
r"Antwort\s*:",
|
||||
r"답변\s*:",
|
||||
r"정답\s*:",
|
||||
r"답\s*:",
|
||||
r"答案\s*:",
|
||||
r"答案\s*:",
|
||||
r"答\s*:",
|
||||
r"答\s*:",
|
||||
r"答复\s*:",
|
||||
r"答曰\s*:",
|
||||
r"الإجابة:",
|
||||
r"الجواب:",
|
||||
r"إجابة:",
|
||||
r"الإجابة النهائية:",
|
||||
r"الإجابة الصحيحة:",
|
||||
r"الإجابة الصحيحة هي:",
|
||||
r"الإجابة هي:",
|
||||
r"Respuesta\s*:",
|
||||
r"Risposta\s*:",
|
||||
r"答え\s*:",
|
||||
r"答え\s*:",
|
||||
r"回答\s*:",
|
||||
r"回答\s*:",
|
||||
r"解答\s*:",
|
||||
r"Jawaban\s*:",
|
||||
r"Réponse\s*:",
|
||||
r"Resposta\s*:",
|
||||
r"Jibu\s*:",
|
||||
r"Idahun\s*:",
|
||||
r"Ìdáhùn\s*:",
|
||||
r"Idáhùn\s*:",
|
||||
r"Àmọ̀nà\s*:",
|
||||
r"Àdáhùn\s*:",
|
||||
r"Ànúgọ\s*:",
|
||||
r"Àṣàyàn\s*:",
|
||||
]
|
||||
|
||||
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
|
||||
r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
|
||||
)
|
||||
|
||||
regex_parser_multiple_choice_answer = ScoringFnDef(
|
||||
identifier="meta-reference::regex_parser_multiple_choice_answer",
|
||||
description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
|
||||
return_type=NumberType(),
|
||||
params=RegexParserScoringFnParams(
|
||||
parsing_regexes=[
|
||||
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
|
||||
for x in MULTILINGUAL_ANSWER_REGEXES
|
||||
],
|
||||
),
|
||||
)
|
|
@ -1,16 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import ScoringFnDef
|
||||
|
||||
|
||||
subset_of = ScoringFnDef(
|
||||
identifier="meta-reference::subset_of",
|
||||
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
|
||||
parameters=[],
|
||||
return_type=NumberType(),
|
||||
)
|
|
@ -1,95 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.inference.inference import Inference
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
BaseScoringFn,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
import re
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
|
||||
aggregate_average,
|
||||
)
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
|
||||
llm_as_judge_8b_correctness,
|
||||
)
|
||||
|
||||
|
||||
class LlmAsJudgeScoringFn(BaseScoringFn):
|
||||
"""
|
||||
A scoring_fn that assigns
|
||||
"""
|
||||
|
||||
def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
|
||||
super().__init__(*arg, **kwargs)
|
||||
self.inference_api = inference_api
|
||||
self.supported_fn_defs_registry = {
|
||||
llm_as_judge_8b_correctness.identifier: llm_as_judge_8b_correctness,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert (
|
||||
scoring_fn_identifier is not None
|
||||
), "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
|
||||
# override params if scoring_params is provided
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
|
||||
assert (
|
||||
fn_def.params.prompt_template is not None
|
||||
), "LLM Judge prompt_template not found."
|
||||
assert (
|
||||
fn_def.params.judge_score_regexes is not None
|
||||
), "LLM Judge judge_score_regexes not found."
|
||||
|
||||
input_query = input_row["input_query"]
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
||||
judge_input_msg = fn_def.params.prompt_template.format(
|
||||
input_query=input_query,
|
||||
expected_answer=expected_answer,
|
||||
generated_answer=generated_answer,
|
||||
)
|
||||
|
||||
judge_response = await self.inference_api.chat_completion(
|
||||
model=fn_def.params.judge_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": judge_input_msg,
|
||||
}
|
||||
],
|
||||
)
|
||||
content = judge_response.completion_message.content
|
||||
rating_regexes = fn_def.params.judge_score_regexes
|
||||
|
||||
judge_rating = None
|
||||
for regex in rating_regexes:
|
||||
match = re.search(regex, content)
|
||||
if match:
|
||||
judge_rating = int(match.group(1))
|
||||
break
|
||||
|
||||
return {
|
||||
"score": judge_rating,
|
||||
"judge_feedback": content,
|
||||
}
|
||||
|
||||
async def aggregate(
|
||||
self, scoring_results: List[ScoringResultRow]
|
||||
) -> Dict[str, Any]:
|
||||
return aggregate_average(scoring_results)
|
|
@ -1,67 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import re
|
||||
|
||||
from .base_scoring_fn import BaseScoringFn
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from .common import aggregate_accuracy
|
||||
|
||||
from .fn_defs.regex_parser_multiple_choice_answer import (
|
||||
regex_parser_multiple_choice_answer,
|
||||
)
|
||||
|
||||
|
||||
class RegexParserScoringFn(BaseScoringFn):
|
||||
"""
|
||||
A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert (
|
||||
scoring_fn_identifier is not None
|
||||
), "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
assert (
|
||||
fn_def.params is not None
|
||||
and fn_def.params.type == ScoringConfigType.regex_parser.value
|
||||
), f"RegexParserScoringFnParams not found for {fn_def}."
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
||||
# parse answer according to regex
|
||||
parsed_answer = None
|
||||
for regex in fn_def.params.parsing_regexes:
|
||||
match = re.search(regex, generated_answer)
|
||||
if match:
|
||||
parsed_answer = match.group(1)
|
||||
break
|
||||
|
||||
score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
||||
async def aggregate(
|
||||
self, scoring_results: List[ScoringResultRow]
|
||||
) -> Dict[str, Any]:
|
||||
return aggregate_accuracy(scoring_results)
|
|
@ -1,49 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.base_scoring_fn import (
|
||||
BaseScoringFn,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.common import (
|
||||
aggregate_accuracy,
|
||||
)
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
|
||||
subset_of,
|
||||
)
|
||||
|
||||
|
||||
class SubsetOfScoringFn(BaseScoringFn):
|
||||
"""
|
||||
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
subset_of.identifier: subset_of,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "subset_of",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
score = 1.0 if expected_answer in generated_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
||||
async def aggregate(
|
||||
self, scoring_results: List[ScoringResultRow]
|
||||
) -> Dict[str, Any]:
|
||||
return aggregate_accuracy(scoring_results)
|
Loading…
Add table
Add a link
Reference in a new issue