diff --git a/distributions/dependencies.json b/distributions/dependencies.json index 1767523d6..15826d40b 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -426,38 +426,6 @@ "transformers", "uvicorn" ], - "open-benchmark": [ - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "fastapi", - "fire", - "httpx", - "litellm", - "matplotlib", - "mcp", - "nltk", - "numpy", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "sqlite-vec", - "together", - "tqdm", - "transformers", - "uvicorn" - ], "passthrough": [ "aiosqlite", "blobfile", diff --git a/llama_stack/providers/inline/eval/__init__.py b/llama_stack/providers/inline/eval/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/eval/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/eval/meta_reference/__init__.py b/llama_stack/providers/inline/eval/meta_reference/__init__.py deleted file mode 100644 index 576a5682b..000000000 --- a/llama_stack/providers/inline/eval/meta_reference/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from typing import Any, Dict - -from llama_stack.distribution.datatypes import Api - -from .config import MetaReferenceEvalConfig - - -async def get_provider_impl( - config: MetaReferenceEvalConfig, - deps: Dict[Api, Any], -): - from .eval import MetaReferenceEvalImpl - - impl = MetaReferenceEvalImpl( - config, - deps[Api.datasetio], - deps[Api.datasets], - deps[Api.inference], - deps[Api.agents], - ) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/eval/meta_reference/config.py b/llama_stack/providers/inline/eval/meta_reference/config.py deleted file mode 100644 index 5b2bec259..000000000 --- a/llama_stack/providers/inline/eval/meta_reference/config.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from typing import Any, Dict - -from pydantic import BaseModel - -from llama_stack.providers.utils.kvstore.config import ( - KVStoreConfig, - SqliteKVStoreConfig, -) - - -class MetaReferenceEvalConfig(BaseModel): - kvstore: KVStoreConfig - - @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: - return { - "kvstore": SqliteKVStoreConfig.sample_run_config( - __distro_dir__=__distro_dir__, - db_name="meta_reference_eval.db", - ) - } diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py deleted file mode 100644 index 6940ad743..000000000 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -import json -from typing import Any, Dict, List - -from tqdm import tqdm - -from llama_stack.apis.agents import Agents, StepType -from llama_stack.apis.benchmarks import Benchmark -from llama_stack.apis.datasetio import DatasetIO -from llama_stack.apis.datasets import Datasets -from llama_stack.apis.inference import Inference, SystemMessage, UserMessage -from llama_stack.providers.datatypes import BenchmarksProtocolPrivate -from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( - MEMORY_QUERY_TOOL, -) -from llama_stack.providers.utils.common.data_schema_validator import ColumnName -from llama_stack.providers.utils.kvstore import kvstore_impl - -from .....apis.common.job_types import Job, JobStatus -from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse -from .config import MetaReferenceEvalConfig - -EVAL_TASKS_PREFIX = "benchmarks:" - - -class MetaReferenceEvalImpl( - Eval, - BenchmarksProtocolPrivate, -): - def __init__( - self, - config: MetaReferenceEvalConfig, - datasetio_api: DatasetIO, - datasets_api: Datasets, - inference_api: Inference, - agents_api: Agents, - ) -> None: - self.config = config - self.datasetio_api = datasetio_api - self.datasets_api = datasets_api - # TODO(xiyan): this implementation will be refactored - self.scoring_api = None - self.inference_api = inference_api - self.agents_api = agents_api - - # TODO: assume sync job, will need jobs API for async scheduling - self.jobs = {} - - self.benchmarks = {} - - async def initialize(self) -> None: - self.kvstore = await kvstore_impl(self.config.kvstore) - # Load existing benchmarks from kvstore - start_key = EVAL_TASKS_PREFIX - end_key = f"{EVAL_TASKS_PREFIX}\xff" - stored_benchmarks = await self.kvstore.range(start_key, end_key) - - for benchmark in stored_benchmarks: - benchmark = Benchmark.model_validate_json(benchmark) - self.benchmarks[benchmark.identifier] = benchmark - - async def shutdown(self) -> None: ... - - async def register_benchmark(self, task_def: Benchmark) -> None: - # Store in kvstore - key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}" - await self.kvstore.set( - key=key, - value=task_def.model_dump_json(), - ) - self.benchmarks[task_def.identifier] = task_def - - async def run_eval( - self, - benchmark_id: str, - benchmark_config: BenchmarkConfig, - ) -> Job: - task_def = self.benchmarks[benchmark_id] - dataset_id = task_def.dataset_id - scoring_functions = task_def.scoring_functions - - # TODO (xiyan): validate dataset schema - # dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - - all_rows = await self.datasetio_api.iterrows( - dataset_id=dataset_id, - limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples), - ) - res = await self.evaluate_rows( - benchmark_id=benchmark_id, - input_rows=all_rows.data, - scoring_functions=scoring_functions, - benchmark_config=benchmark_config, - ) - - # TODO: currently needs to wait for generation before returning - # need job scheduler queue (ray/celery) w/ jobs api - job_id = str(len(self.jobs)) - self.jobs[job_id] = res - return Job(job_id=job_id, status=JobStatus.completed) - - async def _run_agent_generation( - self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig - ) -> List[Dict[str, Any]]: - candidate = benchmark_config.eval_candidate - create_response = await self.agents_api.create_agent(candidate.config) - agent_id = create_response.agent_id - - generations = [] - for i, x in tqdm(enumerate(input_rows)): - assert ColumnName.chat_completion_input.value in x, "Invalid input row" - input_messages = json.loads(x[ColumnName.chat_completion_input.value]) - input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"] - - # NOTE: only single-turn agent generation is supported. Create a new session for each input row - session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}") - session_id = session_create_response.session_id - - turn_request = dict( - agent_id=agent_id, - session_id=session_id, - messages=input_messages, - stream=True, - ) - turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)] - final_event = turn_response[-1].event.payload - - # check if there's a memory retrieval step and extract the context - memory_rag_context = None - for step in final_event.turn.steps: - if step.step_type == StepType.tool_execution.value: - for tool_response in step.tool_responses: - if tool_response.tool_name == MEMORY_QUERY_TOOL: - memory_rag_context = " ".join(x.text for x in tool_response.content) - - agent_generation = {} - agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content - if memory_rag_context: - agent_generation[ColumnName.context.value] = memory_rag_context - - generations.append(agent_generation) - - return generations - - async def _run_model_generation( - self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig - ) -> List[Dict[str, Any]]: - candidate = benchmark_config.eval_candidate - assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided" - - generations = [] - for x in tqdm(input_rows): - if ColumnName.completion_input.value in x: - input_content = json.loads(x[ColumnName.completion_input.value]) - response = await self.inference_api.completion( - model=candidate.model, - content=input_content, - sampling_params=candidate.sampling_params, - ) - generations.append({ColumnName.generated_answer.value: response.completion_message.content}) - elif ColumnName.chat_completion_input.value in x: - chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value]) - input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"] - messages = [] - if candidate.system_message: - messages.append(candidate.system_message) - messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"] - messages += input_messages - response = await self.inference_api.chat_completion( - model_id=candidate.model, - messages=messages, - sampling_params=candidate.sampling_params, - ) - generations.append({ColumnName.generated_answer.value: response.completion_message.content}) - else: - raise ValueError("Invalid input row") - - return generations - - async def evaluate_rows( - self, - benchmark_id: str, - input_rows: List[Dict[str, Any]], - scoring_functions: List[str], - benchmark_config: BenchmarkConfig, - ) -> EvaluateResponse: - candidate = benchmark_config.eval_candidate - if candidate.type == "agent": - generations = await self._run_agent_generation(input_rows, benchmark_config) - elif candidate.type == "model": - generations = await self._run_model_generation(input_rows, benchmark_config) - else: - raise ValueError(f"Invalid candidate type: {candidate.type}") - - # scoring with generated_answer - score_input_rows = [ - input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False) - ] - - if benchmark_config.scoring_params is not None: - scoring_functions_dict = { - scoring_fn_id: benchmark_config.scoring_params.get(scoring_fn_id, None) - for scoring_fn_id in scoring_functions - } - else: - scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions} - - score_response = await self.scoring_api.score( - input_rows=score_input_rows, scoring_functions=scoring_functions_dict - ) - - return EvaluateResponse(generations=generations, scores=score_response.results) - - async def job_status(self, benchmark_id: str, job_id: str) -> Job: - if job_id in self.jobs: - return Job(job_id=job_id, status=JobStatus.completed) - - raise ValueError(f"Job {job_id} not found") - - async def job_cancel(self, benchmark_id: str, job_id: str) -> None: - raise NotImplementedError("Job cancel is not implemented yet") - - async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: - job = await self.job_status(benchmark_id, job_id) - status = job.status - if not status or status != JobStatus.completed: - raise ValueError(f"Job is not completed, Status: {status.value}") - - return self.jobs[job_id] diff --git a/llama_stack/providers/inline/scoring/__init__.py b/llama_stack/providers/inline/scoring/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/scoring/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/scoring/basic/__init__.py b/llama_stack/providers/inline/scoring/basic/__init__.py deleted file mode 100644 index 4898b973a..000000000 --- a/llama_stack/providers/inline/scoring/basic/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from typing import Any, Dict - -from llama_stack.distribution.datatypes import Api - -from .config import BasicScoringConfig - - -async def get_provider_impl( - config: BasicScoringConfig, - deps: Dict[Api, Any], -): - from .scoring import BasicScoringImpl - - impl = BasicScoringImpl( - config, - deps[Api.datasetio], - deps[Api.datasets], - ) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/scoring/basic/config.py b/llama_stack/providers/inline/scoring/basic/config.py deleted file mode 100644 index 5866be359..000000000 --- a/llama_stack/providers/inline/scoring/basic/config.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from typing import Any, Dict - -from pydantic import BaseModel - - -class BasicScoringConfig(BaseModel): - @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: - return {} diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py deleted file mode 100644 index 9a45f7139..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from typing import Any, Dict, List, Optional - -from llama_stack.apis.datasetio import DatasetIO -from llama_stack.apis.datasets import Datasets -from llama_stack.apis.scoring import ( - ScoreBatchResponse, - ScoreResponse, - Scoring, - ScoringResult, -) -from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams -from llama_stack.distribution.datatypes import Api -from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate -from llama_stack.providers.utils.common.data_schema_validator import ( - get_valid_schemas, - validate_dataset_schema, -) - -from .config import BasicScoringConfig -from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn -from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn -from .scoring_fn.equality_scoring_fn import EqualityScoringFn -from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn -from .scoring_fn.regex_parser_math_response_scoring_fn import ( - RegexParserMathResponseScoringFn, -) -from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn -from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn - -FIXED_FNS = [ - EqualityScoringFn, - SubsetOfScoringFn, - RegexParserScoringFn, - RegexParserMathResponseScoringFn, - BFCLScoringFn, - IfEvalScoringFn, - DocVQAScoringFn, -] - - -class BasicScoringImpl( - Scoring, - ScoringFunctionsProtocolPrivate, -): - def __init__( - self, - config: BasicScoringConfig, - datasetio_api: DatasetIO, - datasets_api: Datasets, - ) -> None: - self.config = config - self.datasetio_api = datasetio_api - self.datasets_api = datasets_api - self.scoring_fn_id_impls = {} - - async def initialize(self) -> None: - for fn in FIXED_FNS: - impl = fn() - for fn_defs in impl.get_supported_scoring_fn_defs(): - self.scoring_fn_id_impls[fn_defs.identifier] = impl - - async def shutdown(self) -> None: ... - - async def list_scoring_functions(self) -> List[ScoringFn]: - scoring_fn_defs_list = [ - fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs() - ] - - for f in scoring_fn_defs_list: - assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! " - - return scoring_fn_defs_list - - async def register_scoring_function(self, function_def: ScoringFn) -> None: - raise NotImplementedError("Register scoring function not implemented yet") - - async def score_batch( - self, - dataset_id: str, - scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, - save_results_dataset: bool = False, - ) -> ScoreBatchResponse: - dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) - - all_rows = await self.datasetio_api.iterrows( - dataset_id=dataset_id, - limit=-1, - ) - res = await self.score( - input_rows=all_rows.data, - scoring_functions=scoring_functions, - ) - if save_results_dataset: - # TODO: persist and register dataset on to server for reading - # self.datasets_api.register_dataset() - raise NotImplementedError("Save results dataset not implemented yet") - - return ScoreBatchResponse( - results=res.results, - ) - - async def score( - self, - input_rows: List[Dict[str, Any]], - scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, - ) -> ScoreResponse: - res = {} - for scoring_fn_id in scoring_functions.keys(): - if scoring_fn_id not in self.scoring_fn_id_impls: - raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") - scoring_fn = self.scoring_fn_id_impls[scoring_fn_id] - scoring_fn_params = scoring_functions.get(scoring_fn_id, None) - score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params) - agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params) - res[scoring_fn_id] = ScoringResult( - score_rows=score_results, - aggregated_results=agg_results, - ) - - return ScoreResponse( - results=res, - ) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/__init__.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py deleted file mode 100644 index f37780f3e..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json -import re -from typing import Any, Dict, Optional - -from llama_stack.apis.scoring import ScoringResultRow -from llama_stack.apis.scoring_functions import ScoringFnParams -from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn - -from ..utils.bfcl.ast_parser import decode_ast -from ..utils.bfcl.checker import ast_checker, is_empty_output -from .fn_defs.bfcl import bfcl - - -def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]: - contain_func_call = False - error = None - error_type = None - checker_result = {} - try: - prediction = decode_ast(x["generated_answer"], x["language"]) or "" - contain_func_call = True - # if not is_function_calling_format_output(prediction): - if is_empty_output(prediction): - contain_func_call = False - error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability." - error_type = "ast_decoder:decoder_wrong_output_format" - else: - checker_result = ast_checker( - json.loads(x["function"]), - prediction, - json.loads(x["ground_truth"]), - x["language"], - test_category=test_category, - model_name="", - ) - except Exception as e: - prediction = "" - error = f"Invalid syntax. Failed to decode AST. {str(e)}" - error_type = "ast_decoder:decoder_failed" - return { - "prediction": prediction, - "contain_func_call": contain_func_call, - "valid": checker_result.get("valid", False), - "error": error or checker_result.get("error", ""), - "error_type": error_type or checker_result.get("error_type", ""), - } - - -def gen_valid(x: Dict[str, Any]) -> Dict[str, float]: - return {"valid": x["valid"]} - - -def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]: - # This function serves for both relevance and irrelevance tests, which share the exact opposite logic. - # If `test_category` is "irrelevance", the model is expected to output no function call. - # No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`). - # If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call. - acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"] - return {"valid": float(acc)} - - -class BFCLScoringFn(RegisteredBaseScoringFn): - """ - A scoring_fn for BFCL - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.supported_fn_defs_registry = { - bfcl.identifier: bfcl, - } - - async def score_row( - self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = "bfcl", - scoring_params: Optional[ScoringFnParams] = None, - ) -> ScoringResultRow: - test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"]) - score_result = postprocess(input_row, test_category) - if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}: - score = gen_relevance_acc(score_result)["valid"] - else: - score = gen_valid(score_result)["valid"] - return { - "score": float(score), - } diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py deleted file mode 100644 index 84ca55732..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json -import re -from typing import Any, Dict, Optional - -from llama_stack.apis.scoring import ScoringResultRow -from llama_stack.apis.scoring_functions import ScoringFnParams -from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn - -from .fn_defs.docvqa import docvqa - -CONTRACTIONS = { - "aint": "ain't", - "arent": "aren't", - "cant": "can't", - "couldve": "could've", - "couldnt": "couldn't", - "couldn'tve": "couldn't've", - "couldnt've": "couldn't've", - "didnt": "didn't", - "doesnt": "doesn't", - "dont": "don't", - "hadnt": "hadn't", - "hadnt've": "hadn't've", - "hadn'tve": "hadn't've", - "hasnt": "hasn't", - "havent": "haven't", - "hed": "he'd", - "hed've": "he'd've", - "he'dve": "he'd've", - "hes": "he's", - "howd": "how'd", - "howll": "how'll", - "hows": "how's", - "Id've": "I'd've", - "I'dve": "I'd've", - "Im": "I'm", - "Ive": "I've", - "isnt": "isn't", - "itd": "it'd", - "itd've": "it'd've", - "it'dve": "it'd've", - "itll": "it'll", - "let's": "let's", - "maam": "ma'am", - "mightnt": "mightn't", - "mightnt've": "mightn't've", - "mightn'tve": "mightn't've", - "mightve": "might've", - "mustnt": "mustn't", - "mustve": "must've", - "neednt": "needn't", - "notve": "not've", - "oclock": "o'clock", - "oughtnt": "oughtn't", - "ow's'at": "'ow's'at", - "'ows'at": "'ow's'at", - "'ow'sat": "'ow's'at", - "shant": "shan't", - "shed've": "she'd've", - "she'dve": "she'd've", - "she's": "she's", - "shouldve": "should've", - "shouldnt": "shouldn't", - "shouldnt've": "shouldn't've", - "shouldn'tve": "shouldn't've", - "somebody'd": "somebodyd", - "somebodyd've": "somebody'd've", - "somebody'dve": "somebody'd've", - "somebodyll": "somebody'll", - "somebodys": "somebody's", - "someoned": "someone'd", - "someoned've": "someone'd've", - "someone'dve": "someone'd've", - "someonell": "someone'll", - "someones": "someone's", - "somethingd": "something'd", - "somethingd've": "something'd've", - "something'dve": "something'd've", - "somethingll": "something'll", - "thats": "that's", - "thered": "there'd", - "thered've": "there'd've", - "there'dve": "there'd've", - "therere": "there're", - "theres": "there's", - "theyd": "they'd", - "theyd've": "they'd've", - "they'dve": "they'd've", - "theyll": "they'll", - "theyre": "they're", - "theyve": "they've", - "twas": "'twas", - "wasnt": "wasn't", - "wed've": "we'd've", - "we'dve": "we'd've", - "weve": "we've", - "werent": "weren't", - "whatll": "what'll", - "whatre": "what're", - "whats": "what's", - "whatve": "what've", - "whens": "when's", - "whered": "where'd", - "wheres": "where's", - "whereve": "where've", - "whod": "who'd", - "whod've": "who'd've", - "who'dve": "who'd've", - "wholl": "who'll", - "whos": "who's", - "whove": "who've", - "whyll": "why'll", - "whyre": "why're", - "whys": "why's", - "wont": "won't", - "wouldve": "would've", - "wouldnt": "wouldn't", - "wouldnt've": "wouldn't've", - "wouldn'tve": "wouldn't've", - "yall": "y'all", - "yall'll": "y'all'll", - "y'allll": "y'all'll", - "yall'd've": "y'all'd've", - "y'alld've": "y'all'd've", - "y'all'dve": "y'all'd've", - "youd": "you'd", - "youd've": "you'd've", - "you'dve": "you'd've", - "youll": "you'll", - "youre": "you're", - "youve": "you've", - "1st": "first", - "2nd": "second", - "3rd": "third", -} -NUMBERS = { - "none": "0", - "zero": "0", - "one": "1", - "two": "2", - "three": "3", - "four": "4", - "five": "5", - "six": "6", - "seven": "7", - "eight": "8", - "nine": "9", - "ten": "10", -} -ARTICLES = [ - "a", - "an", - "the", - "to", - "in", - "from", - "by", -] # Contains a bit more than just articles, but we want to get rid of these elements influencing the accuracy -PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)") -COMMA_STRIP = re.compile(r"(\d)(\,)(\d)") -PUNCTUATION = [ - ";", - r"/", - "[", - "]", - '"', - "{", - "}", - "(", - ")", - "=", - "+", - "\\", - "_", - "-", - ">", - "<", - "@", - "`", - ",", - "?", - "!", -] - - -def normalize_answer(s: str) -> str: - # process punctuation - for p in PUNCTUATION: - if (p + " " in s or " " + p in s) or (re.search(COMMA_STRIP, s) is not None): - s = s.replace(p, "") - else: - s = s.replace(p, " ") - s = PERIOD_STRIP.sub("", s, re.UNICODE) - - # process digits and articles - temp_text = s.lower().split() - out_text = [] - for word in temp_text: - word = NUMBERS.setdefault(word, word) - if word not in ARTICLES: - out_text.append(word) - - # standardize contractions - for word_id, word in enumerate(out_text): - if word in CONTRACTIONS: - out_text[word_id] = CONTRACTIONS[word] - return " ".join(out_text) - - -class DocVQAScoringFn(RegisteredBaseScoringFn): - """ - docvqa basically matches the generated answer against several allowed - choices, but we need to normalize the answer to avoid penalizing - trivial differences - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.supported_fn_defs_registry = { - docvqa.identifier: docvqa, - } - - async def score_row( - self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = "docvqa", - scoring_params: Optional[ScoringFnParams] = None, - ) -> ScoringResultRow: - expected_answers = json.loads(input_row["expected_answer"]) - generated_answer = input_row["generated_answer"] - score = 1.0 if normalize_answer(generated_answer) in [normalize_answer(s) for s in expected_answers] else 0.0 - return { - "score": score, - } diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py deleted file mode 100644 index 0bd6bdd48..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict, Optional - -from llama_stack.apis.scoring import ScoringResultRow -from llama_stack.apis.scoring_functions import ScoringFnParams -from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn - -from .fn_defs.equality import equality - - -class EqualityScoringFn(RegisteredBaseScoringFn): - """ - A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise. - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.supported_fn_defs_registry = { - equality.identifier: equality, - } - - async def score_row( - self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = "equality", - scoring_params: Optional[ScoringFnParams] = None, - ) -> ScoringResultRow: - assert "expected_answer" in input_row, "Expected answer not found in input row." - assert "generated_answer" in input_row, "Generated answer not found in input row." - - expected_answer = input_row["expected_answer"] - generated_answer = input_row["generated_answer"] - score = 1.0 if expected_answer == generated_answer else 0.0 - return { - "score": score, - } diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/__init__.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/bfcl.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/bfcl.py deleted file mode 100644 index 392d92c86..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/bfcl.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -bfcl = ScoringFn( - identifier="basic::bfcl", - description="BFCL complex scoring", - return_type=NumberType(), - provider_id="basic", - provider_resource_id="bfcl", - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]), -) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py deleted file mode 100644 index aad3dfe26..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -docvqa = ScoringFn( - identifier="basic::docvqa", - description="DocVQA Visual Question & Answer scoring function", - return_type=NumberType(), - provider_id="basic", - provider_resource_id="docvqa", - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]), -) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py deleted file mode 100644 index 9b24ff791..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -equality = ScoringFn( - identifier="basic::equality", - description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", - provider_id="basic", - provider_resource_id="equality", - return_type=NumberType(), - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]), -) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/ifeval.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/ifeval.py deleted file mode 100644 index adca0791d..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/ifeval.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -ifeval = ScoringFn( - identifier="basic::ifeval", - description="Eval intruction follow capacity by checkping how many instructions can be followed in each example", - return_type=NumberType(), - provider_id="basic", - provider_resource_id="ifeval", - params=BasicScoringFnParams( - aggregation_functions=[AggregationFunctionType.weighted_average], - ), -) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py deleted file mode 100644 index 8b1bf5352..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - RegexParserScoringFnParams, - ScoringFn, -) - -MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P.*)}\$"] - - -regex_parser_math_response = ScoringFn( - identifier="basic::regex_parser_math_response", - description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match", - return_type=NumberType(), - provider_id="basic", - provider_resource_id="regex-parser-math-response", - params=RegexParserScoringFnParams( - parsing_regexes=MATH_ANSWER_REGEXES, - aggregation_functions=[AggregationFunctionType.accuracy], - ), -) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py deleted file mode 100644 index ea04331c9..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - RegexParserScoringFnParams, - ScoringFn, -) - -MULTILINGUAL_ANSWER_REGEXES = [ - r"The best answer is ", - r"Answer\s*:", - r"Answer\s*:​​​​​​", # Korean invisible character - r"উত্তর\s*:", - r"उत्तर\s*:", - r"উত্তরঃ", - r"উত্তর\s*:", - r"Antwort\s*:", - r"답변\s*:", - r"정답\s*:", - r"답\s*:", - r"答案\s*:", - r"答案\s*:", - r"答\s*:", - r"答\s*:", - r"答复\s*:", - r"答曰\s*:", - r"الإجابة:", - r"الجواب:", - r"إجابة:", - r"الإجابة النهائية:", - r"الإجابة الصحيحة:", - r"الإجابة الصحيحة هي:", - r"الإجابة هي:", - r"Respuesta\s*:", - r"Risposta\s*:", - r"答え\s*:", - r"答え\s*:", - r"回答\s*:", - r"回答\s*:", - r"解答\s*:", - r"Jawaban\s*:", - r"Réponse\s*:", - r"Resposta\s*:", - r"Jibu\s*:", - r"Idahun\s*:", - r"Ìdáhùn\s*:", - r"Idáhùn\s*:", - r"Àmọ̀nà\s*:", - r"Àdáhùn\s*:", - r"Ànúgọ\s*:", - r"Àṣàyàn\s*:", -] - -MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])" - -regex_parser_multiple_choice_answer = ScoringFn( - identifier="basic::regex_parser_multiple_choice_answer", - description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result", - return_type=NumberType(), - provider_id="basic", - provider_resource_id="regex-parser-multiple-choice-answer", - params=RegexParserScoringFnParams( - parsing_regexes=[MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x) for x in MULTILINGUAL_ANSWER_REGEXES], - aggregation_functions=[AggregationFunctionType.accuracy], - ), -) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py deleted file mode 100644 index 9cae66fa6..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -subset_of = ScoringFn( - identifier="basic::subset_of", - description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.", - return_type=NumberType(), - provider_id="basic", - provider_resource_id="subset-of", - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]), -) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py deleted file mode 100644 index 6ff856684..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict, Optional - -from llama_stack.apis.scoring import ScoringResultRow -from llama_stack.apis.scoring_functions import ScoringFnParams -from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn - -from .fn_defs.ifeval import ( - ifeval, -) - - -class IfEvalScoringFn(RegisteredBaseScoringFn): - """ - A scoring_fn Instruction-Following Eval (IFEval) benchmark - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.supported_fn_defs_registry = { - ifeval.identifier: ifeval, - } - - async def score_row( - self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = None, - scoring_params: Optional[ScoringFnParams] = None, - ) -> ScoringResultRow: - from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST - - assert scoring_fn_identifier is not None, "Scoring function identifier not found." - fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] - if scoring_params is not None: - fn_def.params = scoring_params - - instruction_list = input_row["instruction_id_list"] - generated_answer = input_row["generated_answer"].strip() - - is_following_list = [] - results = dict( - {k + "_correct": 0.0 for k in INSTRUCTION_LIST}, - **{k + "_total": 0.0 for k in INSTRUCTION_LIST}, - ) - - for index, instruction_id in enumerate(instruction_list): - instruction_cls = INSTRUCTION_DICT[instruction_id] - instruction = instruction_cls(instruction_id) - results[instruction_id + "_total"] += 1.0 - results[instruction_id.split(":")[0] + "_total"] += 1.0 - - clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None} - print(clean_input_row) - instruction.build_description(**clean_input_row) - args = instruction.get_instruction_args() - if args and "prompt" in args: - instruction.build_description(prompt=input_row["prompt"]) - - if generated_answer and instruction.check_following(generated_answer): - is_following_list.append(True) - results[instruction_id + "_correct"] += 1.0 - results[instruction_id.split(":")[0] + "_correct"] += 1.0 - else: - is_following_list.append(False) - - if len(is_following_list) == 0: - return { - "score": 0.0, - "weight": 0.0, - } - - return { - "score": float(sum(is_following_list)) / float(len(is_following_list)), - "weight": float(len(is_following_list)), - } diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py deleted file mode 100644 index d6c78a9ac..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from typing import Any, Dict, Optional - -from llama_stack.apis.scoring import ScoringResultRow -from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType -from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn - -from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex -from .fn_defs.regex_parser_math_response import ( - regex_parser_math_response, -) - - -class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn): - """ - A scoring_fn for math benchamrks that parses answer from generated response according to context and check match with expected_answer. - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.supported_fn_defs_registry = { - regex_parser_math_response.identifier: regex_parser_math_response, - } - - async def score_row( - self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = None, - scoring_params: Optional[ScoringFnParams] = None, - ) -> ScoringResultRow: - assert scoring_fn_identifier is not None, "Scoring function identifier not found." - fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] - if scoring_params is not None: - fn_def.params = scoring_params - - assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, ( - f"RegexParserScoringFnParams not found for {fn_def}." - ) - - expected_answer = input_row["expected_answer"] - generated_answer = input_row["generated_answer"] - - parsing_regexes = fn_def.params.parsing_regexes - assert len(parsing_regexes) == 1, ( - "Only one parsing regex is supported for regex_parser_math_response scoring function." - ) - parsing_regexes = fn_def.params.parsing_regexes[0] - - normalized_generated_answer = normalize_final_answer( - first_answer(generated_answer), - parsing_regexes, - match_first=True, - ) - normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer)) - - normalized_expected_answer = normalize_final_answer(expected_answer, r".*") - normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer)) - - score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0 - return { - "score": score, - } diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py deleted file mode 100644 index 0606a9581..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -import re -from typing import Any, Dict, Optional - -from llama_stack.apis.scoring import ScoringResultRow -from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType -from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn - -from .fn_defs.regex_parser_multiple_choice_answer import ( - regex_parser_multiple_choice_answer, -) - - -class RegexParserScoringFn(RegisteredBaseScoringFn): - """ - A scoring_fn that parses answer from generated response according to context and check match with expected_answer. - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.supported_fn_defs_registry = { - regex_parser_multiple_choice_answer.identifier: regex_parser_multiple_choice_answer, - } - - async def score_row( - self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = None, - scoring_params: Optional[ScoringFnParams] = None, - ) -> ScoringResultRow: - assert scoring_fn_identifier is not None, "Scoring function identifier not found." - fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] - if scoring_params is not None: - fn_def.params = scoring_params - - assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, ( - f"RegexParserScoringFnParams not found for {fn_def}." - ) - - expected_answer = input_row["expected_answer"] - generated_answer = input_row["generated_answer"] - - # parse answer according to regex - parsed_answer = None - for regex in fn_def.params.parsing_regexes: - match = re.search(regex, generated_answer) - if match: - parsed_answer = match.group(1) - break - - score = 1.0 if parsed_answer and parsed_answer == expected_answer else 0.0 - return { - "score": score, - } diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py deleted file mode 100644 index 71defc433..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any, Dict, Optional - -from llama_stack.apis.scoring import ScoringResultRow -from llama_stack.apis.scoring_functions import ScoringFnParams -from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn - -from .fn_defs.subset_of import subset_of - - -class SubsetOfScoringFn(RegisteredBaseScoringFn): - """ - A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise. - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.supported_fn_defs_registry = { - subset_of.identifier: subset_of, - } - - async def score_row( - self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = "subset_of", - scoring_params: Optional[ScoringFnParams] = None, - ) -> ScoringResultRow: - expected_answer = input_row["expected_answer"] - generated_answer = input_row["generated_answer"] - score = 1.0 if expected_answer in generated_answer else 0.0 - return { - "score": score, - } diff --git a/llama_stack/providers/inline/scoring/basic/utils/bfcl/__init__.py b/llama_stack/providers/inline/scoring/basic/utils/bfcl/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/scoring/basic/utils/bfcl/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/scoring/basic/utils/bfcl/ast_parser.py b/llama_stack/providers/inline/scoring/basic/utils/bfcl/ast_parser.py deleted file mode 100644 index 445cdfc77..000000000 --- a/llama_stack/providers/inline/scoring/basic/utils/bfcl/ast_parser.py +++ /dev/null @@ -1,296 +0,0 @@ -# ruff: noqa -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -import ast - -from .tree_sitter import get_parser - - -def parse_java_function_call(source_code): - if not source_code.endswith(";"): - source_code += ";" # Necessary for the parser not to register an error - parser = get_parser("java") - tree = parser.parse(bytes(source_code, "utf8")) - root_node = tree.root_node - - if root_node.has_error: - raise Exception("Error parsing java the source code.") - - def get_text(node): - """Returns the text represented by the node.""" - return source_code[node.start_byte : node.end_byte] - - def traverse_node(node, nested=False): - if node.type == "string_literal": - if nested: - return get_text(node) - # Strip surrounding quotes from string literals - return get_text(node)[1:-1] - elif node.type == "character_literal": - if nested: - return get_text(node) - # Strip surrounding single quotes from character literals - return get_text(node)[1:-1] - """Traverse the node to collect texts for complex structures.""" - if node.type in [ - "identifier", - "class_literal", - "type_identifier", - "method_invocation", - ]: - return get_text(node) - elif node.type == "array_creation_expression": - # Handle array creation expression specifically - type_node = node.child_by_field_name("type") - value_node = node.child_by_field_name("value") - type_text = traverse_node(type_node, True) - value_text = traverse_node(value_node, True) - return f"new {type_text}[]{value_text}" - elif node.type == "object_creation_expression": - # Handle object creation expression specifically - type_node = node.child_by_field_name("type") - arguments_node = node.child_by_field_name("arguments") - type_text = traverse_node(type_node, True) - if arguments_node: - # Process each argument carefully, avoiding unnecessary punctuation - argument_texts = [] - for child in arguments_node.children: - if child.type not in [ - ",", - "(", - ")", - ]: # Exclude commas and parentheses - argument_text = traverse_node(child, True) - argument_texts.append(argument_text) - arguments_text = ", ".join(argument_texts) - return f"new {type_text}({arguments_text})" - else: - return f"new {type_text}()" - elif node.type == "set": - # Handling sets specifically - items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]] - return "{" + ", ".join(items) + "}" - - elif node.child_count > 0: - return "".join(traverse_node(child, True) for child in node.children) - else: - return get_text(node) - - def extract_arguments(args_node): - arguments = {} - for child in args_node.children: - if child.type == "assignment_expression": - # For named parameters - name_node, value_node = child.children[0], child.children[2] - name = get_text(name_node) - value = traverse_node(value_node) - if name in arguments: - if not isinstance(arguments[name], list): - arguments[name] = [arguments[name]] - arguments[name].append(value) - else: - arguments[name] = value - # arguments.append({'name': name, 'value': value}) - elif child.type in ["identifier", "class_literal", "set"]: - # For unnamed parameters and handling sets - value = traverse_node(child) - if None in arguments: - if not isinstance(arguments[None], list): - arguments[None] = [arguments[None]] - arguments[None].append(value) - else: - arguments[None] = value - return arguments - - def traverse(node): - if node.type == "method_invocation": - # Extract the function name and its arguments - method_name = get_text(node.child_by_field_name("name")) - class_name_node = node.child_by_field_name("object") - if class_name_node: - class_name = get_text(class_name_node) - function_name = f"{class_name}.{method_name}" - else: - function_name = method_name - arguments_node = node.child_by_field_name("arguments") - if arguments_node: - arguments = extract_arguments(arguments_node) - for key, value in arguments.items(): - if isinstance(value, list): - raise Exception("Error: Multiple arguments with the same name are not supported.") - return [{function_name: arguments}] - - else: - for child in node.children: - result = traverse(child) - if result: - return result - - result = traverse(root_node) - return result if result else {} - - -def parse_javascript_function_call(source_code): - if not source_code.endswith(";"): - source_code += ";" # Necessary for the parser not to register an error - parser = get_parser("javascript") - # Parse the source code - tree = parser.parse(bytes(source_code, "utf8")) - root_node = tree.root_node - if root_node.has_error: - raise Exception("Error js parsing the source code.") - - # Function to recursively extract argument details - def extract_arguments(node): - args = {} - for child in node.children: - if child.type == "assignment_expression": - # Extract left (name) and right (value) parts of the assignment - name = child.children[0].text.decode("utf-8") - value = child.children[2].text.decode("utf-8") - if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")): - value = value[1:-1] # Trim the quotation marks - if name in args: - if not isinstance(args[name], list): - args[name] = [args[name]] - args[name].append(value) - else: - args[name] = value - - elif child.type == "identifier" or child.type == "true": - # Handle non-named arguments and boolean values - value = child.text.decode("utf-8") - if None in args: - if not isinstance(args[None], list): - args[None] = [args[None]] - args[None].append(value) - else: - args[None] = value - return args - - # Find the function call and extract its name and arguments - if root_node.type == "program": - for child in root_node.children: - if child.type == "expression_statement": - for sub_child in child.children: - if sub_child.type == "call_expression": - function_name = sub_child.children[0].text.decode("utf8") - arguments_node = sub_child.children[1] - parameters = extract_arguments(arguments_node) - for key, value in parameters.items(): - if isinstance(value, list): - raise Exception("Error: Multiple arguments with the same name are not supported.") - result = [{function_name: parameters}] - return result - - -def ast_parse(input_str, language="Python"): - if language == "Python": - cleaned_input = input_str.strip("[]'") - parsed = ast.parse(cleaned_input, mode="eval") - extracted = [] - if isinstance(parsed.body, ast.Call): - extracted.append(resolve_ast_call(parsed.body)) - else: - for elem in parsed.body.elts: - extracted.append(resolve_ast_call(elem)) - return extracted - elif language == "Java": - return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string - elif language == "JavaScript": - return parse_javascript_function_call(input_str[1:-1]) - else: - raise NotImplementedError(f"Unsupported language: {language}") - - -def resolve_ast_call(elem): - # Handle nested attributes for deeply nested module paths - func_parts = [] - func_part = elem.func - while isinstance(func_part, ast.Attribute): - func_parts.append(func_part.attr) - func_part = func_part.value - if isinstance(func_part, ast.Name): - func_parts.append(func_part.id) - func_name = ".".join(reversed(func_parts)) - args_dict = {} - # Parse when args are simply passed as an unnamed dictionary arg - for arg in elem.args: - if isinstance(arg, ast.Dict): - for key, value in zip(arg.keys, arg.values): - if isinstance(key, ast.Constant): - arg_name = key.value - output = resolve_ast_by_type(value) - args_dict[arg_name] = output - for arg in elem.keywords: - output = resolve_ast_by_type(arg.value) - args_dict[arg.arg] = output - return {func_name: args_dict} - - -def resolve_ast_by_type(value): - if isinstance(value, ast.Constant): - if value.value is Ellipsis: - output = "..." - else: - output = value.value - elif isinstance(value, ast.UnaryOp): - output = -value.operand.value - elif isinstance(value, ast.List): - output = [resolve_ast_by_type(v) for v in value.elts] - elif isinstance(value, ast.Dict): - output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)} - elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values - output = value.value - elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments - output = eval(ast.unparse(value)) - elif isinstance(value, ast.Name): - output = value.id - elif isinstance(value, ast.Call): - if len(value.keywords) == 0: - output = ast.unparse(value) - else: - output = resolve_ast_call(value) - elif isinstance(value, ast.Tuple): - output = tuple(resolve_ast_by_type(v) for v in value.elts) - elif isinstance(value, ast.Lambda): - output = eval(ast.unparse(value.body[0].value)) - elif isinstance(value, ast.Ellipsis): - output = "..." - elif isinstance(value, ast.Subscript): - try: - output = ast.unparse(value.body[0].value) - except: - output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]" - else: - raise Exception(f"Unsupported AST type: {type(value)}") - return output - - -def decode_ast(result, language="Python"): - func = result - func = func.replace("\n", "") # remove new line characters - if not func.startswith("["): - func = "[" + func - if not func.endswith("]"): - func = func + "]" - decoded_output = ast_parse(func, language) - return decoded_output - - -def decode_execute(result): - func = result - func = func.replace("\n", "") # remove new line characters - if not func.startswith("["): - func = "[" + func - if not func.endswith("]"): - func = func + "]" - decode_output = ast_parse(func) - execution_list = [] - for function_call in decode_output: - for key, value in function_call.items(): - execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})") - return execution_list diff --git a/llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py b/llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py deleted file mode 100644 index f6aab123c..000000000 --- a/llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py +++ /dev/null @@ -1,989 +0,0 @@ -# ruff: noqa -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -import json -import re -import time -from typing import Any - -# Comment out for now until we actually use the rest checker in evals -# import requests # Do not remove this import even though it seems to be unused. It's used in the executable_checker_rest function. - - -class NoAPIKeyError(Exception): - def __init__(self): - self.message = "❗️Please fill in the API keys in the function_credential_config.json file. If you do not provide the API keys, the executable test category results will be inaccurate." - super().__init__(self.message) - - -REAL_TIME_MATCH_ALLOWED_DIFFERENCE = 0.2 - - -JAVA_TYPE_CONVERSION = { - "byte": int, - "short": int, - "integer": int, - "float": float, - "double": float, - "long": int, - "boolean": bool, - "char": str, - "Array": list, - "ArrayList": list, - "Set": set, - "HashMap": dict, - "Hashtable": dict, - "Queue": list, # this can be `queue.Queue` as well, for simplicity we check with list - "Stack": list, - "String": str, - "any": str, -} - -JS_TYPE_CONVERSION = { - "String": str, - "integer": int, - "float": float, - "Bigint": int, - "Boolean": bool, - "dict": dict, - "array": list, - "any": str, -} - -# We switch to conditional import for the following two imports to avoid unnecessary installations. -# User doesn't need to setup the tree-sitter packages if they are not running the test for that language. -# from js_type_converter import js_type_converter -# from java_type_converter import java_type_converter - -PYTHON_TYPE_MAPPING = { - "string": str, - "integer": int, - "float": float, - "boolean": bool, - "array": list, - "tuple": list, - "dict": dict, - "any": str, -} - -# This is the list of types that we need to recursively check its values -PYTHON_NESTED_TYPE_CHECK_LIST = ["array", "tuple"] - - -NESTED_CONVERSION_TYPE_LIST = ["Array", "ArrayList", "array"] - - -#### Helper functions for AST #### -def find_description(func_descriptions, name): - if type(func_descriptions) == list: - for func_description in func_descriptions: - if func_description["name"] == name: - return func_description - return None - else: - # it is a dict, there is only one function - return func_descriptions - - -def get_possible_answer_type(possible_answer: list): - for answer in possible_answer: - if answer != "": # Optional parameter - return type(answer) - return None - - -def type_checker( - param: str, - value, - possible_answer: list, - expected_type_description: str, - expected_type_converted, - nested_type_converted, -): - # NOTE: This type checker only supports nested type checking for one level deep. - # We didn't implement recursive type checking for nested types, as it's not needed for the current use case and it's very complex. - - result: Any = { - "valid": True, - "error": [], - "is_variable": False, - "error_type": "type_error:simple", - } - - is_variable = False - # check for the case where a variable is used instead of a actual value. - # use the type in possible_answer as the expected type - possible_answer_type = get_possible_answer_type(possible_answer) - # if possible_answer only contains optional parameters, we can't determine the type - if possible_answer_type != None: - # we are being precise here. - # in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer - if possible_answer_type != expected_type_converted: - is_variable = True - - # value is the same type as in function description - if type(value) == expected_type_converted: - # We don't need to do recursive check for simple types - if nested_type_converted == None: - result["is_variable"] = is_variable - return result - else: - for possible_answer_item in possible_answer: - flag = True # Each parameter should match to at least one possible answer type. - # Here, we assume that each item should be the same type. We could also relax it. - if type(possible_answer_item) == list: - for value_item in value: - checker_result = type_checker( - param, - value_item, - possible_answer_item, - str(nested_type_converted), - nested_type_converted, - None, - ) - if not checker_result["valid"]: - flag = False - break - - if flag: - return {"valid": True, "error": [], "is_variable": is_variable} - - result["valid"] = False - result["error"] = [ - f"Nested type checking failed for parameter {repr(param)}. Expected outer type {expected_type_description} with inner type {str(nested_type_converted)}. Parameter value: {repr(value)}." - ] - result["error_type"] = "type_error:nested" - - # value is not as expected, check for the case where a variable is used instead of a actual value - # use the type in possible_answer as the expected type - possible_answer_type = get_possible_answer_type(possible_answer) - # if possible_answer only contains optional parameters, we can't determine the type - if possible_answer_type != None: - # we are being precise here. - # in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer - if type(value) == possible_answer_type: - result["is_variable"] = True - return result - - result["valid"] = False - result["error"].append( - f"Incorrect type for parameter {repr(param)}. Expected type {expected_type_description}, got {type(value).__name__}. Parameter value: {repr(value)}." - ) - result["error_type"] = "type_error:simple" - return result - - -def standardize_string(input_string: str): - # This function standardizes the string by removing all the spaces, ",./-_*^" punctuation, and converting it to lowercase - # It will also convert all the single quotes to double quotes - # This is used to compare the model output with the possible answers - # We don't want to punish model for answer like April 1, 2024 vs April 1,2024, vs April 1 2024 - regex_string = r"[ \,\.\/\-\_\*\^]" - return re.sub(regex_string, "", input_string).lower().replace("'", '"') - - -def string_checker(param: str, model_output: str, possible_answer: list): - standardize_possible_answer = [] - standardize_model_output = standardize_string(model_output) - for i in range(len(possible_answer)): - if type(possible_answer[i]) == str: - standardize_possible_answer.append(standardize_string(possible_answer[i])) - - if standardize_model_output not in standardize_possible_answer: - return { - "valid": False, - "error": [ - f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}. Case insensitive." - ], - "error_type": "value_error:string", - } - - return {"valid": True, "error": []} - - -def list_checker(param: str, model_output: list, possible_answer: list): - # Convert the tuple to a list - - standardize_model_output = list(model_output) - - # If the element in the list is a string, we need to standardize it - for i in range(len(standardize_model_output)): - if type(standardize_model_output[i]) == str: - standardize_model_output[i] = standardize_string(model_output[i]) - - standardize_possible_answer: Any = [] - # We also need to standardize the possible answers - for i in range(len(possible_answer)): - standardize_possible_answer.append([]) - for j in range(len(possible_answer[i])): - if type(possible_answer[i][j]) == str: - standardize_possible_answer[i].append(standardize_string(possible_answer[i][j])) - else: - standardize_possible_answer[i].append(possible_answer[i][j]) - - if standardize_model_output not in standardize_possible_answer: - return { - "valid": False, - "error": [ - f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}." - ], - "error_type": "value_error:list/tuple", - } - - return {"valid": True, "error": []} - - -def dict_checker(param: str, model_output: dict, possible_answers: list): - # This function works for simple dictionaries, but not dictionaries with nested dictionaries. - # The current dataset only contains simple dictionaries, so this is sufficient. - - result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"} - for i in range(len(possible_answers)): - if possible_answers[i] == "": - continue - - result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"} - - flag = True - - possible_answer = possible_answers[i] - # possible_anwer is a single dictionary - - for key, value in model_output.items(): - if key not in possible_answer: - result["valid"] = False - result["error"].append(f"Unexpected dict key parameter: '{key}'.") # type: ignore[attr-defined] - result["error_type"] = "value_error:dict_key" - flag = False - break - - standardize_value = value - # If the value is a string, we need to standardize it - if type(value) == str: - standardize_value = standardize_string(value) - - # We also need to standardize the possible answers if they are string - standardize_possible_answer = [] - for i in range(len(possible_answer[key])): - if type(possible_answer[key][i]) == str: - standardize_possible_answer.append(standardize_string(possible_answer[key][i])) - else: - standardize_possible_answer.append(possible_answer[key][i]) - - if standardize_value not in standardize_possible_answer: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Invalid value for parameter {repr(key)}: {repr(value)}. Expected one of {standardize_possible_answer}." - ) - result["error_type"] = "value_error:dict_value" - flag = False - break - - for key, value in possible_answer.items(): - if key not in model_output and "" not in value: - result["valid"] = False - result["error"].append(f"Missing dict key parameter: '{key}'.") # type: ignore[attr-defined] - result["error_type"] = "value_error:dict_key" - flag = False - break - - if flag: - return {"valid": True, "error": []} - - return result - - -def list_dict_checker(param: str, model_output: list, possible_answers: list): - # This function takes in a list of dictionaries and checks if each dictionary is valid - # The order of the dictionaries in the list must match the order of the possible answers - - result = {"valid": False, "error": [], "error_type": "list_dict_checker:unclear"} - - for answer_index in range(len(possible_answers)): - flag = True # True means so far, all dictionaries are valid - - # Only proceed if the number of dictionaries in the list matches the number of dictionaries in the possible answers - if len(model_output) != len(possible_answers[answer_index]): - result["valid"] = False - result["error"] = ["Wrong number of dictionaries in the list."] - result["error_type"] = "value_error:list_dict_count" - flag = False - continue - - for dict_index in range(len(model_output)): - result = dict_checker( - param, - model_output[dict_index], - [possible_answers[answer_index][dict_index]], - ) - if not result["valid"]: - flag = False - break - if flag: - return {"valid": True, "error": []} - - return result - - -def simple_function_checker( - func_description: dict, - model_output: dict, - possible_answer: dict, - language: str, - model_name: str, -): - possible_answer = list(possible_answer.values())[0] - # Extract function name and parameters details - func_name = func_description["name"] - param_details = func_description["parameters"]["properties"] - required_params = func_description["parameters"]["required"] - - # Initialize a result dictionary - result = { - "valid": True, - "error": [], - "error_type": "simple_function_checker:unclear", - } - - # Check if function name matches - if func_name not in model_output: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Function name {repr(func_name)} not found in model output." - ) - result["error_type"] = "simple_function_checker:wrong_func_name" - return result - - model_params = model_output[func_name] - - # Check for required parameters in model output - for param in required_params: - if param not in model_params: - result["valid"] = False - result["error"].append(f"Missing required parameter: {repr(param)}.") # type: ignore[attr-defined] - result["error_type"] = "simple_function_checker:missing_required" - return result - - # Validate types and values for each parameter in model output - for param, value in model_params.items(): - if param not in param_details or param not in possible_answer: - result["valid"] = False - result["error"].append(f"Unexpected parameter: {repr(param)}.") # type: ignore[attr-defined] - result["error_type"] = "simple_function_checker:unexpected_param" - return result - - full_param_details = param_details[param] - expected_type_description = full_param_details["type"] # This is a string - is_variable = False - nested_type_converted = None - - if language == "Java": - from evals.utils.bfcl.java_type_converter import java_type_converter - - expected_type_converted = JAVA_TYPE_CONVERSION[expected_type_description] - - if expected_type_description in JAVA_TYPE_CONVERSION: - if type(value) != str: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}." - ) - result["error_type"] = "type_error:java" - return result - - if expected_type_description in NESTED_CONVERSION_TYPE_LIST: - nested_type = param_details[param]["items"]["type"] - nested_type_converted = JAVA_TYPE_CONVERSION[nested_type] - value = java_type_converter(value, expected_type_description, nested_type) - else: - value = java_type_converter(value, expected_type_description) - - elif language == "JavaScript": - from evals.utils.bfcl.js_type_converter import js_type_converter - - expected_type_converted = JS_TYPE_CONVERSION[expected_type_description] - - if expected_type_description in JS_TYPE_CONVERSION: - if type(value) != str: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}." - ) - result["error_type"] = "type_error:js" - return result - - if expected_type_description in NESTED_CONVERSION_TYPE_LIST: - nested_type = param_details[param]["items"]["type"] - nested_type_converted = JS_TYPE_CONVERSION[nested_type] - value = js_type_converter(value, expected_type_description, nested_type) - else: - value = js_type_converter(value, expected_type_description) - - elif language == "Python": - expected_type_converted = PYTHON_TYPE_MAPPING[expected_type_description] - if expected_type_description in PYTHON_NESTED_TYPE_CHECK_LIST: - nested_type = param_details[param]["items"]["type"] - nested_type_converted = PYTHON_TYPE_MAPPING[nested_type] - - # We convert all tuple value to list when the expected type is tuple. - # The conversion is necessary because any tuple in the possible answer would become a list after being processed through json.dump() and json.load(). - # This does introduce some false positive (eg, when the model provides a list value instead of tuple). We hope to find a better solution in the future. - if expected_type_description == "tuple" and type(value) == tuple: - value = list(value) - - # Allow python auto conversion from int to float - if language == "Python" and expected_type_description == "float" and type(value) == int: - value = float(value) - - # Type checking - # In fact, we only check for Python here. - # Type check for other languages are handled by the type converter, and so their value (after conversion) is always correct. - type_check_result = type_checker( - param, - value, - possible_answer[param], - expected_type_description, - expected_type_converted, - nested_type_converted, - ) - is_variable = type_check_result["is_variable"] - if not type_check_result["valid"]: - return type_check_result - - # It doesn't make sense to special handle dictionaries and list of dictionaries if the value is a variable. - # We can just treat the variable as a string and use the normal flow. - if not is_variable: - # Special handle for dictionaries - if expected_type_converted == dict: - result = dict_checker(param, value, possible_answer[param]) - if not result["valid"]: - return result - continue - - # Special handle for list of dictionaries - elif expected_type_converted == list and nested_type_converted == dict: - result = list_dict_checker(param, value, possible_answer[param]) - if not result["valid"]: - return result - continue - - # Special handle for strings - elif expected_type_converted == str: - # We don't check for case sensitivity for string, as long as it's not a variable - result = string_checker(param, value, possible_answer[param]) - if not result["valid"]: - return result - continue - - elif expected_type_converted == list: - result = list_checker(param, value, possible_answer[param]) - if not result["valid"]: - return result - continue - - # Check if the value is within the possible answers - if value not in possible_answer[param]: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Invalid value for parameter {repr(param)}: {repr(value)}. Expected one of {possible_answer[param]}." - ) - result["error_type"] = "value_error:others" - return result - - # Check for optional parameters not provided but allowed - for param in possible_answer: - if param not in model_params and "" not in possible_answer[param]: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Optional parameter {repr(param)} not provided and not marked as optional." - ) - result["error_type"] = "simple_function_checker:missing_optional" - return result - - return result - - -def parallel_function_checker_enforce_order( - func_descriptions: list, - model_output: list, - possible_answers: dict, - language: str, - model_name: str, -): - if len(model_output) != len(possible_answers): - return { - "valid": False, - "error": ["Wrong number of functions."], - "error_type": "parallel_function_checker_enforce_order:wrong_count", - } - - func_name_list = list(possible_answers.keys()) - possible_answers_list = [] - - for key, value in possible_answers.items(): - possible_answers_list.append({key: value}) - - for i in range(len(possible_answers_list)): - func_description = find_description(func_descriptions, func_name_list[i]) - - result = simple_function_checker( - func_description, - model_output[i], - possible_answers_list[i], - language, - model_name, - ) - if not result["valid"]: - return result - - return {"valid": True, "error": []} - - -def parallel_function_checker_no_order( - func_descriptions: list, - model_output: list, - possible_answers: list, - language: str, - model_name: str, -): - if len(model_output) != len(possible_answers): - return { - "valid": False, - "error": ["Wrong number of functions."], - "error_type": "parallel_function_checker_no_order:wrong_count", - } - - matched_indices = [] - - # We go throught the possible answers one by one, and eliminate the model output that matches the possible answer - # It must be this way because we need ground truth to fetch the correct function description - for i in range(len(possible_answers)): - # possible_answers[i] is a dictionary with only one key - func_name_expected = list(possible_answers[i].keys())[0] - func_description = find_description(func_descriptions, func_name_expected) - - all_errors = [] - - for index in range(len(model_output)): - if index in matched_indices: - continue - - result = simple_function_checker( - func_description, - model_output[index], - possible_answers[i], - language, - model_name, - ) - - if result["valid"]: - matched_indices.append(index) - break - else: - all_errors.append( - { - f"Model Result Index {index}": { - "sub_error": result["error"], - "sub_error_type": result["error_type"], - "model_output_item": model_output[index], - "possible_answer_item": possible_answers[i], - } - } - ) - - if not result["valid"]: - considered_indices = [i for i in range(len(model_output)) if i not in matched_indices] - all_errors.insert( - 0, - f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type] - ) - return { - "valid": False, - "error": all_errors, - "error_type": "parallel_function_checker_no_order:cannot_find_match", - } - - return {"valid": True, "error": []} - - -def multiple_function_checker( - func_descriptions: list, - model_output: list, - possible_answers: list, - language: str, - model_name: str, -): - if len(model_output) != len(possible_answers): - return { - "valid": False, - "error": ["Wrong number of functions."], - "error_type": "multiple_function_checker:wrong_count", - } - - # possible_answers is a list of only one dictionary with only one key - func_name_expected = list(possible_answers[0].keys())[0] - func_description = find_description(func_descriptions, func_name_expected) - return simple_function_checker( - func_description, - model_output[0], - possible_answers[0], - language, - model_name, - ) - - -def patten_matcher(exec_output, expected_result, function_call, is_sanity_check): - result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"} - - if type(exec_output) != type(expected_result): - return { - "valid": False, - "error": [ - f"Wrong execution result type for {repr(function_call)}. Expected type: {type(expected_result)}, but got: {type(exec_output)}." - ], - "error_type": "executable_checker:wrong_result_type", - "model_executed_output": exec_output, - } - if type(exec_output) == dict: - # We loose the requirement for the sanity check as the expected result used in the sanity check might not be the most up-to-date one. - # This happens when the key is a timestamp or a random number. - if is_sanity_check: - if len(exec_output) != len(expected_result): - return { - "valid": False, - "error": [ - f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}." - ], - "error_type": "executable_checker:wrong_result_type:dict_length", - "model_executed_output": exec_output, - } - else: - return result - - for key, value in expected_result.items(): - if key not in exec_output: - return { - "valid": False, - "error": [ - f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not found in the model output." - ], - "error_type": "executable_checker:wrong_result_type:dict_key_not_found", - "model_executed_output": exec_output, - } - for key, value in exec_output.items(): - if key not in expected_result: - return { - "valid": False, - "error": [ - f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not expected in the model output." - ], - "error_type": "executable_checker:wrong_result_type:dict_extra_key", - "model_executed_output": exec_output, - } - if type(exec_output) == list: - if len(exec_output) != len(expected_result): - return { - "valid": False, - "error": [ - f"Wrong execution result pattern for {repr(function_call)}. Expect type list, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}." - ], - "error_type": "executable_checker:wrong_result_type:list_length", - "model_executed_output": exec_output, - } - return result - - -#### Helper functions for Exec #### -def executable_checker_simple( - function_call: str, - expected_result, - expected_result_type: str, - is_sanity_check=False, -): - result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"} - - exec_dict: Any = {} - - try: - exec( - "from executable_python_function import *" + "\nresult=" + function_call, - exec_dict, - ) - exec_output = exec_dict["result"] - except NoAPIKeyError as e: - raise e - except Exception as e: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Error in execution: {repr(function_call)}. Error: {str(e)}" - ) - result["error_type"] = "executable_checker:execution_error" - return result - - # We need to special handle the case where the execution result is a tuple and convert it to a list - # Because when json is stored, the tuple is converted to a list, and so the expected result is a list when loaded from json - if isinstance(exec_output, tuple): - exec_output = list(exec_output) - - if expected_result_type == "exact_match": - if exec_output != expected_result: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}." - ) - result["error_type"] = "executable_checker:wrong_result" - result["model_executed_output"] = exec_output - return result - - elif expected_result_type == "real_time_match": - # Allow for 5% difference - if (type(expected_result) == float or type(expected_result) == int) and ( - type(exec_output) == float or type(exec_output) == int - ): - if not ( - expected_result * (1 - REAL_TIME_MATCH_ALLOWED_DIFFERENCE) - <= exec_output - <= expected_result * (1 + REAL_TIME_MATCH_ALLOWED_DIFFERENCE) - ): - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. {REAL_TIME_MATCH_ALLOWED_DIFFERENCE * 100}% difference allowed." - ) - result["error_type"] = "executable_checker:wrong_result_real_time" - result["model_executed_output"] = exec_output - return result - else: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. Type needs to be float or int for real time match criteria." - ) - result["error_type"] = "executable_checker:wrong_result_real_time" - result["model_executed_output"] = exec_output - return result - - else: - # structural match - pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check) - if not pattern_match_result["valid"]: - return pattern_match_result - - return result - - -def executable_checker_parallel_no_order( - decoded_result: list, expected_exec_result: list, expected_exec_result_type: list -): - if len(decoded_result) != len(expected_exec_result): - return { - "valid": False, - "error": [ - f"Wrong number of functions provided. Expected {len(expected_exec_result)}, but got {len(decoded_result)}." - ], - "error_type": "value_error:exec_result_count", - } - - matched_indices = [] - for i in range(len(expected_exec_result)): - all_errors = [] - for index in range(len(decoded_result)): - if index in matched_indices: - continue - - result = executable_checker_simple( - decoded_result[index], - expected_exec_result[i], - expected_exec_result_type[i], - False, - ) - - if result["valid"]: - matched_indices.append(index) - break - else: - all_errors.append( - { - f"Model Result Index {index}": { - "sub_error": result["error"], - "sub_error_type": result["error_type"], - "model_executed_output": ( - result["model_executed_output"] if "model_executed_output" in result else None - ), - } - } - ) - - if not result["valid"]: - considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices] - all_errors.insert( - 0, - f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type] - ) - return { - "valid": False, - "error": all_errors, - "error_type": "executable_checker:cannot_find_match", - } - - return {"valid": True, "error": [], "error_type": "executable_checker:unclear"} - - -#### Main function #### -def executable_checker_rest(func_call, idx): - # Move this here for now to avoid needing to read this file / fix paths to be relative to dataset_dir. Fix when it's actually needed / used. - EVAL_GROUND_TRUTH_PATH = "/mnt/wsfuse/fair_llm_v2/datasets/eval/bfcl/rest-eval-response_v5.jsonl" # Ground truth file for v5 for rest execution - with open(EVAL_GROUND_TRUTH_PATH, "r") as f: - EVAL_GROUND_TRUTH = f.readlines() - if "https://geocode.maps.co" in func_call: - time.sleep(2) - if "requests_get" in func_call: - func_call = func_call.replace("requests_get", "requests.get") - try: - response = eval(func_call) - except Exception as e: - return { - "valid": False, - "error": [f"Execution failed. {str(e)}"], - "error_type": "executable_checker_rest:execution_error", - } - - try: - if response.status_code == 200: - eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx]) - try: - if isinstance(eval_GT_json, dict): - if isinstance(response.json(), dict): - if set(eval_GT_json.keys()) == set(response.json().keys()): - return {"valid": True, "error": [], "error_type": ""} - return { - "valid": False, - "error": ["Key inconsistency"], - "error_type": "executable_checker_rest:wrong_key", - } - return { - "valid": False, - "error": [f"Expected dictionary, but got {type(response.json())}"], - "error_type": "executable_checker_rest:wrong_type", - } - - elif isinstance(eval_GT_json, list): - if isinstance(response.json(), list): - if len(eval_GT_json) != len(response.json()): - return { - "valid": False, - "error": [f"Response list length inconsistency."], - "error_type": "value_error:exec_result_rest_count", - } - - else: - for i in range(len(eval_GT_json)): - if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()): - return { - "valid": False, - "error": [f"Key inconsistency"], - "error_type": "executable_checker_rest:wrong_key", - } - - return {"valid": True, "error": []} - else: - return { - "valid": False, - "error": [f"Expected list, but got {type(response.json())}"], - "error_type": "executable_checker_rest:wrong_type", - } - return { - "valid": False, - "error": [f"Expected dict or list, but got {type(response.json())}"], - "error_type": "executable_checker_rest:wrong_type", - } - except Exception as e: - return { - "valid": False, - "error": [ - f"Error in execution and type checking. Status code: {response.status_code}. Error: {str(e)}" - ], - "error_type": "executable_checker_rest:response_format_error", - } - else: - return { - "valid": False, - "error": [f"Execution result status code is not 200, got {response.status_code}"], - "error_type": "executable_checker_rest:wrong_status_code", - } - except Exception as e: - return { - "valid": False, - "error": [f"Cannot get status code of the response. Error: {str(e)}"], - "error_type": "executable_checker_rest:cannot_get_status_code", - } - - -def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name): - if "parallel" in test_category: - return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name) - - elif "multiple" in test_category: - return multiple_function_checker(func_description, model_output, possible_answer, language, model_name) - - else: - if len(model_output) != 1: - return { - "valid": False, - "error": ["Wrong number of functions."], - "error_type": "simple_function_checker:wrong_count", - } - - return simple_function_checker( - func_description[0], - model_output[0], - possible_answer[0], - language, - model_name, - ) - - -def exec_checker(decoded_result: list, func_description: dict, test_category: str): - if "multiple" in test_category or "parallel" in test_category: - return executable_checker_parallel_no_order( - decoded_result, - func_description["execution_result"], - func_description["execution_result_type"], - ) - - else: - if len(decoded_result) != 1: - return { - "valid": False, - "error": ["Wrong number of functions."], - "error_type": "simple_exec_checker:wrong_count", - } - return executable_checker_simple( - decoded_result[0], - func_description["execution_result"][0], - func_description["execution_result_type"][0], - False, - ) - - -def is_empty_output(decoded_output): - # This function is a patch to the ast decoder for relevance detection - # Sometimes the ast decoder will parse successfully, but the input doens't really have a function call - # [], [{}], and anything that is not in function calling format is considered empty (and thus should be marked as correct) - if not is_function_calling_format_output(decoded_output): - return True - if len(decoded_output) == 0: - return True - if len(decoded_output) == 1 and len(decoded_output[0]) == 0: - return True - - -def is_function_calling_format_output(decoded_output): - # Ensure the output is a list of dictionaries - if type(decoded_output) == list: - for item in decoded_output: - if type(item) != dict: - return False - return True - return False diff --git a/llama_stack/providers/inline/scoring/basic/utils/bfcl/tree_sitter.py b/llama_stack/providers/inline/scoring/basic/utils/bfcl/tree_sitter.py deleted file mode 100644 index ed97ee360..000000000 --- a/llama_stack/providers/inline/scoring/basic/utils/bfcl/tree_sitter.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -""" -Tree-sitter changes its API with unfortunate frequency. Modules that need it should -import it from here so that we can centrally manage things as necessary. -""" - -# These currently work with tree-sitter 0.23.0 -# NOTE: Don't import tree-sitter or any of the language modules in the main module -# because not all environments have them. Import lazily inside functions where needed. - -import importlib -import typing - -if typing.TYPE_CHECKING: - import tree_sitter - - -def get_language(language: str) -> "tree_sitter.Language": - import tree_sitter - - language_module_name = f"tree_sitter_{language}" - try: - language_module = importlib.import_module(language_module_name) - except ModuleNotFoundError as exc: - raise ValueError( - f"Language {language} is not found. Please install the tree-sitter-{language} package." - ) from exc - return tree_sitter.Language(language_module.language()) - - -def get_parser(language: str, **kwargs) -> "tree_sitter.Parser": - import tree_sitter - - lang = get_language(language) - return tree_sitter.Parser(lang, **kwargs) diff --git a/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py b/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py deleted file mode 100644 index 28605159f..000000000 --- a/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py +++ /dev/null @@ -1,3319 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import collections -import functools -import json -import logging -import random -import re -import string -from types import MappingProxyType -from typing import Dict, Iterable, List, Optional, Sequence, Union - -import emoji -import langdetect -import nltk -from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai -from pythainlp.tokenize import word_tokenize as word_tokenize_thai - -logger = logging.getLogger() - -WORD_LIST = [ - "western", - "sentence", - "signal", - "dump", - "spot", - "opposite", - "bottom", - "potato", - "administration", - "working", - "welcome", - "morning", - "good", - "agency", - "primary", - "wish", - "responsibility", - "press", - "problem", - "president", - "steal", - "brush", - "read", - "type", - "beat", - "trainer", - "growth", - "lock", - "bone", - "case", - "equal", - "comfortable", - "region", - "replacement", - "performance", - "mate", - "walk", - "medicine", - "film", - "thing", - "rock", - "tap", - "total", - "competition", - "ease", - "south", - "establishment", - "gather", - "parking", - "world", - "plenty", - "breath", - "claim", - "alcohol", - "trade", - "dear", - "highlight", - "street", - "matter", - "decision", - "mess", - "agreement", - "studio", - "coach", - "assist", - "brain", - "wing", - "style", - "private", - "top", - "brown", - "leg", - "buy", - "procedure", - "method", - "speed", - "high", - "company", - "valuable", - "pie", - "analyst", - "session", - "pattern", - "district", - "pleasure", - "dinner", - "swimming", - "joke", - "order", - "plate", - "department", - "motor", - "cell", - "spend", - "cabinet", - "difference", - "power", - "examination", - "engine", - "horse", - "dimension", - "pay", - "toe", - "curve", - "literature", - "bother", - "fire", - "possibility", - "debate", - "activity", - "passage", - "hello", - "cycle", - "background", - "quiet", - "author", - "effect", - "actor", - "page", - "bicycle", - "error", - "throat", - "attack", - "character", - "phone", - "tea", - "increase", - "outcome", - "file", - "specific", - "inspector", - "internal", - "potential", - "staff", - "building", - "employer", - "shoe", - "hand", - "direction", - "garden", - "purchase", - "interview", - "study", - "recognition", - "member", - "spiritual", - "oven", - "sandwich", - "weird", - "passenger", - "particular", - "response", - "reaction", - "size", - "variation", - "a", - "cancel", - "candy", - "exit", - "guest", - "condition", - "fly", - "price", - "weakness", - "convert", - "hotel", - "great", - "mouth", - "mind", - "song", - "sugar", - "suspect", - "telephone", - "ear", - "roof", - "paint", - "refrigerator", - "organization", - "jury", - "reward", - "engineering", - "day", - "possession", - "crew", - "bar", - "road", - "description", - "celebration", - "score", - "mark", - "letter", - "shower", - "suggestion", - "sir", - "luck", - "national", - "progress", - "hall", - "stroke", - "theory", - "offer", - "story", - "tax", - "definition", - "history", - "ride", - "medium", - "opening", - "glass", - "elevator", - "stomach", - "question", - "ability", - "leading", - "village", - "computer", - "city", - "grand", - "confidence", - "candle", - "priest", - "recommendation", - "point", - "necessary", - "body", - "desk", - "secret", - "horror", - "noise", - "culture", - "warning", - "water", - "round", - "diet", - "flower", - "bus", - "tough", - "permission", - "week", - "prompt", - "connection", - "abuse", - "height", - "save", - "corner", - "border", - "stress", - "drive", - "stop", - "rip", - "meal", - "listen", - "confusion", - "girlfriend", - "living", - "relation", - "significance", - "plan", - "creative", - "atmosphere", - "blame", - "invite", - "housing", - "paper", - "drink", - "roll", - "silver", - "drunk", - "age", - "damage", - "smoke", - "environment", - "pack", - "savings", - "influence", - "tourist", - "rain", - "post", - "sign", - "grandmother", - "run", - "profit", - "push", - "clerk", - "final", - "wine", - "swim", - "pause", - "stuff", - "singer", - "funeral", - "average", - "source", - "scene", - "tradition", - "personal", - "snow", - "nobody", - "distance", - "sort", - "sensitive", - "animal", - "major", - "negotiation", - "click", - "mood", - "period", - "arrival", - "expression", - "holiday", - "repeat", - "dust", - "closet", - "gold", - "bad", - "sail", - "combination", - "clothes", - "emphasis", - "duty", - "black", - "step", - "school", - "jump", - "document", - "professional", - "lip", - "chemical", - "front", - "wake", - "while", - "inside", - "watch", - "row", - "subject", - "penalty", - "balance", - "possible", - "adult", - "aside", - "sample", - "appeal", - "wedding", - "depth", - "king", - "award", - "wife", - "blow", - "site", - "camp", - "music", - "safe", - "gift", - "fault", - "guess", - "act", - "shame", - "drama", - "capital", - "exam", - "stupid", - "record", - "sound", - "swing", - "novel", - "minimum", - "ratio", - "machine", - "shape", - "lead", - "operation", - "salary", - "cloud", - "affair", - "hit", - "chapter", - "stage", - "quantity", - "access", - "army", - "chain", - "traffic", - "kick", - "analysis", - "airport", - "time", - "vacation", - "philosophy", - "ball", - "chest", - "thanks", - "place", - "mountain", - "advertising", - "red", - "past", - "rent", - "return", - "tour", - "house", - "construction", - "net", - "native", - "war", - "figure", - "fee", - "spray", - "user", - "dirt", - "shot", - "task", - "stick", - "friend", - "software", - "promotion", - "interaction", - "surround", - "block", - "purpose", - "practice", - "conflict", - "routine", - "requirement", - "bonus", - "hole", - "state", - "junior", - "sweet", - "catch", - "tear", - "fold", - "wall", - "editor", - "life", - "position", - "pound", - "respect", - "bathroom", - "coat", - "script", - "job", - "teach", - "birth", - "view", - "resolve", - "theme", - "employee", - "doubt", - "market", - "education", - "serve", - "recover", - "tone", - "harm", - "miss", - "union", - "understanding", - "cow", - "river", - "association", - "concept", - "training", - "recipe", - "relationship", - "reserve", - "depression", - "proof", - "hair", - "revenue", - "independent", - "lift", - "assignment", - "temporary", - "amount", - "loss", - "edge", - "track", - "check", - "rope", - "estimate", - "pollution", - "stable", - "message", - "delivery", - "perspective", - "mirror", - "assistant", - "representative", - "witness", - "nature", - "judge", - "fruit", - "tip", - "devil", - "town", - "emergency", - "upper", - "drop", - "stay", - "human", - "neck", - "speaker", - "network", - "sing", - "resist", - "league", - "trip", - "signature", - "lawyer", - "importance", - "gas", - "choice", - "engineer", - "success", - "part", - "external", - "worker", - "simple", - "quarter", - "student", - "heart", - "pass", - "spite", - "shift", - "rough", - "lady", - "grass", - "community", - "garage", - "youth", - "standard", - "skirt", - "promise", - "blind", - "television", - "disease", - "commission", - "positive", - "energy", - "calm", - "presence", - "tune", - "basis", - "preference", - "head", - "common", - "cut", - "somewhere", - "presentation", - "current", - "thought", - "revolution", - "effort", - "master", - "implement", - "republic", - "floor", - "principle", - "stranger", - "shoulder", - "grade", - "button", - "tennis", - "police", - "collection", - "account", - "register", - "glove", - "divide", - "professor", - "chair", - "priority", - "combine", - "peace", - "extension", - "maybe", - "evening", - "frame", - "sister", - "wave", - "code", - "application", - "mouse", - "match", - "counter", - "bottle", - "half", - "cheek", - "resolution", - "back", - "knowledge", - "make", - "discussion", - "screw", - "length", - "accident", - "battle", - "dress", - "knee", - "log", - "package", - "it", - "turn", - "hearing", - "newspaper", - "layer", - "wealth", - "profile", - "imagination", - "answer", - "weekend", - "teacher", - "appearance", - "meet", - "bike", - "rise", - "belt", - "crash", - "bowl", - "equivalent", - "support", - "image", - "poem", - "risk", - "excitement", - "remote", - "secretary", - "public", - "produce", - "plane", - "display", - "money", - "sand", - "situation", - "punch", - "customer", - "title", - "shake", - "mortgage", - "option", - "number", - "pop", - "window", - "extent", - "nothing", - "experience", - "opinion", - "departure", - "dance", - "indication", - "boy", - "material", - "band", - "leader", - "sun", - "beautiful", - "muscle", - "farmer", - "variety", - "fat", - "handle", - "director", - "opportunity", - "calendar", - "outside", - "pace", - "bath", - "fish", - "consequence", - "put", - "owner", - "go", - "doctor", - "information", - "share", - "hurt", - "protection", - "career", - "finance", - "force", - "golf", - "garbage", - "aspect", - "kid", - "food", - "boot", - "milk", - "respond", - "objective", - "reality", - "raw", - "ring", - "mall", - "one", - "impact", - "area", - "news", - "international", - "series", - "impress", - "mother", - "shelter", - "strike", - "loan", - "month", - "seat", - "anything", - "entertainment", - "familiar", - "clue", - "year", - "glad", - "supermarket", - "natural", - "god", - "cost", - "conversation", - "tie", - "ruin", - "comfort", - "earth", - "storm", - "percentage", - "assistance", - "budget", - "strength", - "beginning", - "sleep", - "other", - "young", - "unit", - "fill", - "store", - "desire", - "hide", - "value", - "cup", - "maintenance", - "nurse", - "function", - "tower", - "role", - "class", - "camera", - "database", - "panic", - "nation", - "basket", - "ice", - "art", - "spirit", - "chart", - "exchange", - "feedback", - "statement", - "reputation", - "search", - "hunt", - "exercise", - "nasty", - "notice", - "male", - "yard", - "annual", - "collar", - "date", - "platform", - "plant", - "fortune", - "passion", - "friendship", - "spread", - "cancer", - "ticket", - "attitude", - "island", - "active", - "object", - "service", - "buyer", - "bite", - "card", - "face", - "steak", - "proposal", - "patient", - "heat", - "rule", - "resident", - "broad", - "politics", - "west", - "knife", - "expert", - "girl", - "design", - "salt", - "baseball", - "grab", - "inspection", - "cousin", - "couple", - "magazine", - "cook", - "dependent", - "security", - "chicken", - "version", - "currency", - "ladder", - "scheme", - "kitchen", - "employment", - "local", - "attention", - "manager", - "fact", - "cover", - "sad", - "guard", - "relative", - "county", - "rate", - "lunch", - "program", - "initiative", - "gear", - "bridge", - "breast", - "talk", - "dish", - "guarantee", - "beer", - "vehicle", - "reception", - "woman", - "substance", - "copy", - "lecture", - "advantage", - "park", - "cold", - "death", - "mix", - "hold", - "scale", - "tomorrow", - "blood", - "request", - "green", - "cookie", - "church", - "strip", - "forever", - "beyond", - "debt", - "tackle", - "wash", - "following", - "feel", - "maximum", - "sector", - "sea", - "property", - "economics", - "menu", - "bench", - "try", - "language", - "start", - "call", - "solid", - "address", - "income", - "foot", - "senior", - "honey", - "few", - "mixture", - "cash", - "grocery", - "link", - "map", - "form", - "factor", - "pot", - "model", - "writer", - "farm", - "winter", - "skill", - "anywhere", - "birthday", - "policy", - "release", - "husband", - "lab", - "hurry", - "mail", - "equipment", - "sink", - "pair", - "driver", - "consideration", - "leather", - "skin", - "blue", - "boat", - "sale", - "brick", - "two", - "feed", - "square", - "dot", - "rush", - "dream", - "location", - "afternoon", - "manufacturer", - "control", - "occasion", - "trouble", - "introduction", - "advice", - "bet", - "eat", - "kill", - "category", - "manner", - "office", - "estate", - "pride", - "awareness", - "slip", - "crack", - "client", - "nail", - "shoot", - "membership", - "soft", - "anybody", - "web", - "official", - "individual", - "pizza", - "interest", - "bag", - "spell", - "profession", - "queen", - "deal", - "resource", - "ship", - "guy", - "chocolate", - "joint", - "formal", - "upstairs", - "car", - "resort", - "abroad", - "dealer", - "associate", - "finger", - "surgery", - "comment", - "team", - "detail", - "crazy", - "path", - "tale", - "initial", - "arm", - "radio", - "demand", - "single", - "draw", - "yellow", - "contest", - "piece", - "quote", - "pull", - "commercial", - "shirt", - "contribution", - "cream", - "channel", - "suit", - "discipline", - "instruction", - "concert", - "speech", - "low", - "effective", - "hang", - "scratch", - "industry", - "breakfast", - "lay", - "join", - "metal", - "bedroom", - "minute", - "product", - "rest", - "temperature", - "many", - "give", - "argument", - "print", - "purple", - "laugh", - "health", - "credit", - "investment", - "sell", - "setting", - "lesson", - "egg", - "middle", - "marriage", - "level", - "evidence", - "phrase", - "love", - "self", - "benefit", - "guidance", - "affect", - "you", - "dad", - "anxiety", - "special", - "boyfriend", - "test", - "blank", - "payment", - "soup", - "obligation", - "reply", - "smile", - "deep", - "complaint", - "addition", - "review", - "box", - "towel", - "minor", - "fun", - "soil", - "issue", - "cigarette", - "internet", - "gain", - "tell", - "entry", - "spare", - "incident", - "family", - "refuse", - "branch", - "can", - "pen", - "grandfather", - "constant", - "tank", - "uncle", - "climate", - "ground", - "volume", - "communication", - "kind", - "poet", - "child", - "screen", - "mine", - "quit", - "gene", - "lack", - "charity", - "memory", - "tooth", - "fear", - "mention", - "marketing", - "reveal", - "reason", - "court", - "season", - "freedom", - "land", - "sport", - "audience", - "classroom", - "law", - "hook", - "win", - "carry", - "eye", - "smell", - "distribution", - "research", - "country", - "dare", - "hope", - "whereas", - "stretch", - "library", - "if", - "delay", - "college", - "plastic", - "book", - "present", - "use", - "worry", - "champion", - "goal", - "economy", - "march", - "election", - "reflection", - "midnight", - "slide", - "inflation", - "action", - "challenge", - "guitar", - "coast", - "apple", - "campaign", - "field", - "jacket", - "sense", - "way", - "visual", - "remove", - "weather", - "trash", - "cable", - "regret", - "buddy", - "beach", - "historian", - "courage", - "sympathy", - "truck", - "tension", - "permit", - "nose", - "bed", - "son", - "person", - "base", - "meat", - "usual", - "air", - "meeting", - "worth", - "game", - "independence", - "physical", - "brief", - "play", - "raise", - "board", - "she", - "key", - "writing", - "pick", - "command", - "party", - "yesterday", - "spring", - "candidate", - "physics", - "university", - "concern", - "development", - "change", - "string", - "target", - "instance", - "room", - "bitter", - "bird", - "football", - "normal", - "split", - "impression", - "wood", - "long", - "meaning", - "stock", - "cap", - "leadership", - "media", - "ambition", - "fishing", - "essay", - "salad", - "repair", - "today", - "designer", - "night", - "bank", - "drawing", - "inevitable", - "phase", - "vast", - "chip", - "anger", - "switch", - "cry", - "twist", - "personality", - "attempt", - "storage", - "being", - "preparation", - "bat", - "selection", - "white", - "technology", - "contract", - "side", - "section", - "station", - "till", - "structure", - "tongue", - "taste", - "truth", - "difficulty", - "group", - "limit", - "main", - "move", - "feeling", - "light", - "example", - "mission", - "might", - "wait", - "wheel", - "shop", - "host", - "classic", - "alternative", - "cause", - "agent", - "consist", - "table", - "airline", - "text", - "pool", - "craft", - "range", - "fuel", - "tool", - "partner", - "load", - "entrance", - "deposit", - "hate", - "article", - "video", - "summer", - "feature", - "extreme", - "mobile", - "hospital", - "flight", - "fall", - "pension", - "piano", - "fail", - "result", - "rub", - "gap", - "system", - "report", - "suck", - "ordinary", - "wind", - "nerve", - "ask", - "shine", - "note", - "line", - "mom", - "perception", - "brother", - "reference", - "bend", - "charge", - "treat", - "trick", - "term", - "homework", - "bake", - "bid", - "status", - "project", - "strategy", - "orange", - "let", - "enthusiasm", - "parent", - "concentrate", - "device", - "travel", - "poetry", - "business", - "society", - "kiss", - "end", - "vegetable", - "employ", - "schedule", - "hour", - "brave", - "focus", - "process", - "movie", - "illegal", - "general", - "coffee", - "ad", - "highway", - "chemistry", - "psychology", - "hire", - "bell", - "conference", - "relief", - "show", - "neat", - "funny", - "weight", - "quality", - "club", - "daughter", - "zone", - "touch", - "tonight", - "shock", - "burn", - "excuse", - "name", - "survey", - "landscape", - "advance", - "satisfaction", - "bread", - "disaster", - "item", - "hat", - "prior", - "shopping", - "visit", - "east", - "photo", - "home", - "idea", - "father", - "comparison", - "cat", - "pipe", - "winner", - "count", - "lake", - "fight", - "prize", - "foundation", - "dog", - "keep", - "ideal", - "fan", - "struggle", - "peak", - "safety", - "solution", - "hell", - "conclusion", - "population", - "strain", - "alarm", - "measurement", - "second", - "train", - "race", - "due", - "insurance", - "boss", - "tree", - "monitor", - "sick", - "course", - "drag", - "appointment", - "slice", - "still", - "care", - "patience", - "rich", - "escape", - "emotion", - "royal", - "female", - "childhood", - "government", - "picture", - "will", - "sock", - "big", - "gate", - "oil", - "cross", - "pin", - "improvement", - "championship", - "silly", - "help", - "sky", - "pitch", - "man", - "diamond", - "most", - "transition", - "work", - "science", - "committee", - "moment", - "fix", - "teaching", - "dig", - "specialist", - "complex", - "guide", - "people", - "dead", - "voice", - "original", - "break", - "topic", - "data", - "degree", - "reading", - "recording", - "bunch", - "reach", - "judgment", - "lie", - "regular", - "set", - "painting", - "mode", - "list", - "player", - "bear", - "north", - "wonder", - "carpet", - "heavy", - "officer", - "negative", - "clock", - "unique", - "baby", - "pain", - "assumption", - "disk", - "iron", - "bill", - "drawer", - "look", - "double", - "mistake", - "finish", - "future", - "brilliant", - "contact", - "math", - "rice", - "leave", - "restaurant", - "discount", - "sex", - "virus", - "bit", - "trust", - "event", - "wear", - "juice", - "failure", - "bug", - "context", - "mud", - "whole", - "wrap", - "intention", - "draft", - "pressure", - "cake", - "dark", - "explanation", - "space", - "angle", - "word", - "efficiency", - "management", - "habit", - "star", - "chance", - "finding", - "transportation", - "stand", - "criticism", - "flow", - "door", - "injury", - "insect", - "surprise", - "apartment", -] # pylint: disable=line-too-long - -# ISO 639-1 codes to language names. -LANGUAGE_CODES = MappingProxyType( - { - "en": "English", - "es": "Spanish", - "pt": "Portuguese", - "ar": "Arabic", - "hi": "Hindi", - "fr": "French", - "ru": "Russian", - "de": "German", - "ja": "Japanese", - "it": "Italian", - "bn": "Bengali", - "uk": "Ukrainian", - "th": "Thai", - "ur": "Urdu", - "ta": "Tamil", - "te": "Telugu", - "bg": "Bulgarian", - "ko": "Korean", - "pl": "Polish", - "he": "Hebrew", - "fa": "Persian", - "vi": "Vietnamese", - "ne": "Nepali", - "sw": "Swahili", - "kn": "Kannada", - "mr": "Marathi", - "gu": "Gujarati", - "pa": "Punjabi", - "ml": "Malayalam", - "fi": "Finnish", - } -) - -# Chinese characters -_CHINESE_CHARS_PATTERN = r"[\u4E00-\u9FFF\u3400-\u4DBF]" -# Japanese Hiragana & Katakana -_JAPANESE_CHARS_PATTERN = r"[\u3040-\u309f\u30a0-\u30ff]" -# Korean (Hangul Syllables) -_KOREAN_CHARS_PATTERN = r"[\uAC00-\uD7AF]" -_ALPHABETS = "([A-Za-z])" -_PREFIXES = "(Mr|St|Mrs|Ms|Dr)[.]" -_SUFFIXES = "(Inc|Ltd|Jr|Sr|Co)" -_STARTERS = ( - r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)" -) -_ACRONYMS = "([A-Z][.][A-Z][.](?:[A-Z][.])?)" -_WEBSITES = "[.](com|net|org|io|gov|edu|me)" -_DIGITS = "([0-9])" -_MULTIPLE_DOTS = r"\.{2,}" - - -# Util functions -def split_into_sentences(text): - """Split the text into sentences. - - Args: - text: A string that consists of more than or equal to one sentences. - - Returns: - A list of strings where each string is a sentence. - """ - text = " " + text + " " - text = text.replace("\n", " ") - text = re.sub(_PREFIXES, "\\1", text) - text = re.sub(_WEBSITES, "\\1", text) - text = re.sub(_DIGITS + "[.]" + _DIGITS, "\\1\\2", text) - text = re.sub( - _MULTIPLE_DOTS, - lambda match: "" * len(match.group(0)) + "", - text, - ) - if "Ph.D" in text: - text = text.replace("Ph.D.", "PhD") - text = re.sub(r"\s" + _ALPHABETS + "[.] ", " \\1 ", text) - text = re.sub(_ACRONYMS + " " + _STARTERS, "\\1 \\2", text) - text = re.sub( - _ALPHABETS + "[.]" + _ALPHABETS + "[.]" + _ALPHABETS + "[.]", - "\\1\\2\\3", - text, - ) - text = re.sub(_ALPHABETS + "[.]" + _ALPHABETS + "[.]", "\\1\\2", text) - text = re.sub(" " + _SUFFIXES + "[.] " + _STARTERS, " \\1 \\2", text) - text = re.sub(" " + _SUFFIXES + "[.]", " \\1", text) - text = re.sub(" " + _ALPHABETS + "[.]", " \\1", text) - if "”" in text: - text = text.replace(".”", "”.") - if '"' in text: - text = text.replace('."', '".') - if "!" in text: - text = text.replace('!"', '"!') - if "?" in text: - text = text.replace('?"', '"?') - text = text.replace(".", ".") - text = text.replace("?", "?") - text = text.replace("!", "!") - text = text.replace("", ".") - sentences = text.split("") - sentences = [s.strip() for s in sentences] - if sentences and not sentences[-1]: - sentences = sentences[:-1] - return sentences - - -def count_words(text): - """Counts the number of words.""" - tokenizer = nltk.tokenize.RegexpTokenizer(r"\w+") - tokens = tokenizer.tokenize(text) - num_words = len(tokens) - return num_words - - -def split_chinese_japanese_hindi(lines: str) -> Iterable[str]: - """ - Split Chinese and Japanese text into sentences. - From https://stackoverflow.com/questions/27441191/splitting-chinese-document-into-sentences - Special question/exclamation marks were added upon inspection of our raw data, - Also supports multiple lines. - The separator for hindi is '।' - """ - for line in lines.splitlines(): - for sent in re.findall( - r"[^!?。\.\!\?\!\?\.\n।]+[!?。\.\!\?\!\?\.\n।]?", - line.strip(), - flags=re.U, - ): - yield sent - - -def count_words_cjk(text: str) -> int: - """Counts the number of words for Chinese and Japanese and Korean. - Can be extended to additional languages. - Source: https://stackoverflow.com/questions/49164507/how-to-count-the-number-of-chinese-korean-and-english-words withadditional modifications - Example: - >In: count_words_cjk('こんにちは、ジェイソンさん、Jason? Nice to meet you☺ ❤') - >Out: 19 - """ - # Non alpha numeric patterns in latin and asian languages. - non_alphanumeric_patterns = ( - r"[\\.\!\?\.\/_,\{\}<>:;$%^&*(+\"\'+——!,。?、`~@#¥……():;《)《》“”()\[\]«»〔〕\-「」]+" - ) - text = re.sub(non_alphanumeric_patterns, "", text) - - emoji_cnt = emoji.emoji_count(text) # count emojis - text = emoji.replace_emoji(text, "") # remove emojis - - foreign_chars_patterns = "|".join([_CHINESE_CHARS_PATTERN, _JAPANESE_CHARS_PATTERN, _KOREAN_CHARS_PATTERN]) - asian_chars = re.findall(foreign_chars_patterns, text) - asian_chars_cnt = len(asian_chars) - non_asian_chars = re.sub(foreign_chars_patterns, " ", text) - non_asian_words_cnt = len(non_asian_chars.split()) - - return non_asian_words_cnt + asian_chars_cnt + emoji_cnt - - -@functools.lru_cache(maxsize=None) -def _get_sentence_tokenizer(): - return nltk.data.load("nltk:tokenizers/punkt/english.pickle") - - -def count_sentences(text): - """Count the number of sentences.""" - tokenizer = _get_sentence_tokenizer() - tokenized_sentences = tokenizer.tokenize(text) - return len(tokenized_sentences) - - -def get_langid(text: str, lid_path: Optional[str] = None) -> str: - line_langs: List[str] = [] - lines = [line.strip() for line in text.split("\n") if len(line.strip()) >= 4] - - for line in lines: - try: - line_langs.append(langdetect.detect(line)) - except langdetect.LangDetectException as e: - logger.info("Unable to detect language for text %s due to %s", line, e) # refex: disable=pytotw.037 - - if len(line_langs) == 0: - return "en" - # select the text language to be the most commonly predicted language of the lines. - return collections.Counter(line_langs).most_common(1)[0][0] - - -def generate_keywords(num_keywords): - """Randomly generates a few keywords.""" - return random.sample(WORD_LIST, k=num_keywords) - - -"""Library of instructions""" -_InstructionArgsDtype = Optional[Dict[str, Union[int, str, Sequence[str]]]] - -_LANGUAGES = LANGUAGE_CODES - -# The relational operation for comparison. -_COMPARISON_RELATION = ("less than", "at least") - -# The maximum number of sentences. -_MAX_NUM_SENTENCES = 20 - -# The number of placeholders. -_NUM_PLACEHOLDERS = 4 - -# The number of bullet lists. -_NUM_BULLETS = 5 - -# The options of constrained response. -_CONSTRAINED_RESPONSE_OPTIONS = ( - "My answer is yes.", - "My answer is no.", - "My answer is maybe.", -) - -# The options of starter keywords. -_STARTER_OPTIONS = ( - "I would say", - "My answer is", - "I believe", - "In my opinion", - "I think", - "I reckon", - "I feel", - "From my perspective", - "As I see it", - "According to me", - "As far as I'm concerned", - "To my understanding", - "In my view", - "My take on it is", - "As per my perception", -) - -# The options of ending keywords. -# TODO(jeffreyzhou) add more ending options -_ENDING_OPTIONS = ("Any other questions?", "Is there anything else I can help with?") - -# The number of highlighted sections. -_NUM_HIGHLIGHTED_SECTIONS = 4 - -# The section spliter. -_SECTION_SPLITER = ("Section", "SECTION") - -# The number of sections. -_NUM_SECTIONS = 5 - -# The number of paragraphs. -_NUM_PARAGRAPHS = 5 - -# The postscript marker. -_POSTSCRIPT_MARKER = ("P.S.", "P.P.S") - -# The number of keywords. -_NUM_KEYWORDS = 2 - -# The occurrences of a single keyword. -_KEYWORD_FREQUENCY = 3 - -# The occurrences of a single letter. -_LETTER_FREQUENCY = 10 - -# The occurrences of words with all capital letters. -_ALL_CAPITAL_WORD_FREQUENCY = 20 - -# The number of words in the response. -_NUM_WORDS_LOWER_LIMIT = 100 -_NUM_WORDS_UPPER_LIMIT = 500 - - -class Instruction: - """An instruction template.""" - - def __init__(self, instruction_id): - self.id = instruction_id - - def build_description(self, **kwargs): - raise NotImplementedError("`build_description` not implemented.") - - def get_instruction_args(self): - raise NotImplementedError("`get_instruction_args` not implemented.") - - def get_instruction_args_keys(self): - raise NotImplementedError("`get_instruction_args_keys` not implemented.") - - def check_following(self, value): - raise NotImplementedError("`check_following` not implemented.") - - -class ResponseLanguageChecker(Instruction): - """Check the language of the entire response.""" - - def build_description(self, *, language=None): - """Build the instruction description. - - Args: - language: A string representing the expected language of the response. The - language has to comply to the 97 types defined in - `langid.py` (https://pypi.org/project/langid/1.1.5/), which follows - ISO 639-1 codes (https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes); - for example, `en` for English, `zh` for Chinese, `fr` for French. - - Returns: - A string representing the instruction description. - """ - self._language = language - if self._language is None: - self._language = random.choice(list(_LANGUAGES.keys())) - - self._description_pattern = ( - "Your ENTIRE response should be in {language} language, no other " + "language is allowed." - ) - return self._description_pattern.format(language=_LANGUAGES[self._language]) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"language": self._language} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["language"] - - def check_following(self, value): - """Check if the language of the entire response follows the instruction. - - Args: - value: A string representing the response. - - Returns: - True if the language of `value` follows instruction; otherwise False. - """ - assert isinstance(value, str) - - try: - return langdetect.detect(value) == self._language - except langdetect.LangDetectException as e: - # Count as instruction is followed. - logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 - return True - - -class NumberOfSentences(Instruction): - """Check the number of sentences.""" - - def build_description(self, *, num_sentences=None, relation=None): - """Build the instruction description. - - Args: - num_sentences: An integer specifying the number of sentences as a - threshold. - relation: A string in (`less than`, `at least`), defining the relational - operator for comparison. - Two relational comparisons are supported for now: - if 'less than', the actual number of sentences < the threshold; - if 'at least', the actual number of sentences >= the threshold. - - Returns: - A string representing the instruction description. - """ - # The number of sentences as a threshold for comparison. - self._num_sentences_threshold = num_sentences - if self._num_sentences_threshold is None or self._num_sentences_threshold < 0: - self._num_sentences_threshold = random.randint(1, _MAX_NUM_SENTENCES) - - if relation is None: - self._comparison_relation = random.choice(_COMPARISON_RELATION) - elif relation not in _COMPARISON_RELATION: - raise ValueError( - f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {relation} is given." - ) - else: - self._comparison_relation = relation - - self._description_pattern = "Your response should contain {relation} {num_sentences} sentences." - return self._description_pattern.format( - relation=self._comparison_relation, - num_sentences=self._num_sentences_threshold, - ) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return { - "num_sentences": self._num_sentences_threshold, - "relation": self._comparison_relation, - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_sentences", "relation"] - - def check_following(self, value): - """Check if the number of sentences follows the instruction. - - Args: - value: A string representing the response. - - Returns: - True if the response follows the instruction. - - Raise: - ValueError if the string in `instruction_args` is not in - [`less_than`, `at_least`]. - """ - lang = get_langid(value) - if lang == "th": - # Counting Newline also as a new sentence: - num_sentences = sum([len(sent_tokenize_thai(line)) for line in value.splitlines()]) - elif lang in ["zh", "zh-cn", "zh-tw", "ja", "hi"]: - num_sentences = len(list(split_chinese_japanese_hindi(value))) - else: - num_sentences = count_sentences(value) - if self._comparison_relation == _COMPARISON_RELATION[0]: - return num_sentences < self._num_sentences_threshold - elif self._comparison_relation == _COMPARISON_RELATION[1]: - return num_sentences >= self._num_sentences_threshold - - -class PlaceholderChecker(Instruction): - """Check the placeholders in template writing.""" - - def build_description(self, *, num_placeholders=None): - """Build the instruction description. - - Args: - num_placeholders: An integer denoting the minimum number of - placeholders required in the response. - - Returns: - A string representing the instruction description. - """ - self._num_placeholders = num_placeholders - if self._num_placeholders is None or self._num_placeholders < 0: - self._num_placeholders = random.randint(1, _NUM_PLACEHOLDERS) - self._description_pattern = ( - "The response must contain at least {num_placeholders} placeholders " - + "represented by square brackets, such as [address]." - ) - return self._description_pattern.format(num_placeholders=self._num_placeholders) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"num_placeholders": self._num_placeholders} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_placeholders"] - - def check_following(self, value): - """Check if the number of placeholders follows the instruction. - - Args: - value: A string representing the response. - - Returns: - True if the actual number of placeholders in the response is greater than - or equal to `num_placeholders`; otherwise, False. - """ - placeholders = re.findall(r"\[.*?\]", value) - num_placeholders = len(placeholders) - return num_placeholders >= self._num_placeholders - - -class BulletListChecker(Instruction): - """Checks the bullet list in the prompt.""" - - def build_description(self, *, num_bullets=None): - """Build the instruction description. - - Args: - num_bullets: An integer specifying the exact number of bullet lists - that is required to appear in the response. - - Returns: - A string representing the instruction description. - """ - self._num_bullets = num_bullets - if self._num_bullets is None or self._num_bullets < 0: - self._num_bullets = random.randint(1, _NUM_BULLETS) - self._description_pattern = ( - "Your answer must contain exactly {num_bullets} bullet points. " - + "Use the markdown bullet points such as:\n" - + "* This is point 1. \n" - + "* This is point 2" - ) - return self._description_pattern.format(num_bullets=self._num_bullets) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"num_bullets": self._num_bullets} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_bullets"] - - def check_following(self, value): - r"""Check if the number of bullet lists meets the requirement. - - Args: - value: A string representing the response. The response is expected to - contain some bullet lists that start with `\*`. - - Returns: - True if the actual number of bullet lists in the response meets the - requirement. - """ - bullet_lists = re.findall(r"^\s*\*[^\*].*$", value, flags=re.MULTILINE) - bullet_lists_2 = re.findall(r"^\s*-.*$", value, flags=re.MULTILINE) - num_bullet_lists = len(bullet_lists) + len(bullet_lists_2) - return num_bullet_lists == self._num_bullets - - -class ConstrainedResponseChecker(Instruction): - """Checks the constrained response.""" - - def build_description(self): - """Build the instruction description.""" - # A sequence of string(s) representing the options of the expected response. - self._constrained_responses = _CONSTRAINED_RESPONSE_OPTIONS - self._description_pattern = "Answer with one of the following options: {response_options}" - return self._description_pattern.format(response_options=self._constrained_responses) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - """Checks if the response matches the constrained options. - - Args: - value: A string representing the response. - - Returns: - True if the actual response contains one of the options in the constrained - responses; otherwise False. - """ - value = value.strip() - for constrained_response in self._constrained_responses: - if constrained_response in value: - return True - return False - - -class ConstrainedStartChecker(Instruction): - """Checks the response start.""" - - def build_description(self, *, starter=None): - """Build the instruction description. - - Args: - starter: A string representing the keyward that the response should start - with. - - Returns: - A string representing the instruction description. - """ - self._starter = starter.strip() if isinstance(starter, str) else starter - if self._starter is None: - self._starter = random.choice(_STARTER_OPTIONS) - self._description_pattern = ( - "During the conversation, when it is your turn, " + "please always start with {starter}" - ) - return self._description_pattern.format(starter=self._starter) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"starter": self._starter} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["starter"] - - def check_following(self, value): - """Checks if the response starts with the constrained keyword or phrase. - - Args: - value: A string representing the response. - - Returns: - True if the response starts with the given phrase or keyword that is - contained in `instruction_args`; otherwise, False. - """ - response_pattern = r"^\s*" + self._starter + r".*$" - response_with_constrained_start = re.search(response_pattern, value, flags=re.MULTILINE) - return True if response_with_constrained_start else False - - -class HighlightSectionChecker(Instruction): - """Checks the highlighted section.""" - - def build_description(self, *, num_highlights=None): - """Build the instruction description. - - Args: - num_highlights: An integer specifying the minimum number of highlighted - sections. - - Returns: - A string representing the instruction description. - """ - self._num_highlights = num_highlights - if self._num_highlights is None or self._num_highlights < 0: - self._num_highlights = random.randint(1, _NUM_HIGHLIGHTED_SECTIONS) - - self._description_pattern = ( - "Highlight at least {num_highlights} sections in your answer with " - + "markdown, i.e. *highlighted section*." - ) - - return self._description_pattern.format(num_highlights=self._num_highlights) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"num_highlights": self._num_highlights} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_highlights"] - - def check_following(self, value): - """Checks if the number of highlighted sections meets the requirement. - - Args: - value: a string repesenting the response. The response is expected to - contain highlighted sections in the format of *highlighted*. - - Returns: - True if the actual number of highlighted sections in the format of - *highlighed sections* meets the minimum requirement; otherwise False. - """ - num_highlights = 0 - highlights = re.findall(r"\*[^\n\*]*\*", value) - double_highlights = re.findall(r"\*\*[^\n\*]*\*\*", value) - for highlight in highlights: - if highlight.strip("*").strip(): - num_highlights += 1 - for highlight in double_highlights: - if highlight.removeprefix("**").removesuffix("**").strip(): - num_highlights += 1 - - return num_highlights >= self._num_highlights - - -class SectionChecker(Instruction): - """Checks the sections.""" - - def build_description(self, *, section_spliter=None, num_sections=None): - """Build the instruction description. - - Args: - section_spliter: A string represents the section spliter keyword that - marks a new section, i.e., `Section` or `SECTION`. - num_sections: An integer specifying the number of sections. - - Returns: - A string representing the instruction description. - """ - self._section_spliter = section_spliter.strip() if isinstance(section_spliter, str) else section_spliter - if self._section_spliter is None: - self._section_spliter = random.choice(_SECTION_SPLITER) - - self._num_sections = num_sections - if self._num_sections is None or self._num_sections < 0: - self._num_sections = random.randint(1, _NUM_SECTIONS) - - self._description_pattern = ( - "Your response must have {num_sections} sections. Mark the beginning " - + "of each section with {section_spliter} X, such as:\n" - + "{section_spliter} 1\n" - + "[content of section 1]\n" - + "{section_spliter} 2\n" - + "[content of section 2]" - ) - - return self._description_pattern.format(num_sections=self._num_sections, section_spliter=self._section_spliter) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return { - "section_spliter": self._section_spliter, - "num_sections": self._num_sections, - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["section_spliter", "num_sections"] - - def check_following(self, value): - """Checks the response contains multiple sections. - - Args: - value: A string representing the response. The response is expected - to contain multiple sections (number of sections is greater than 1). - A new section starts with `Section 1`, where the number denotes the - section index. - - Returns: - True if the number of sections in the response is greater than or equal to - the minimum number of sections; otherwise, False. - """ - section_splitter_patten = r"\s?" + self._section_spliter + r"\s?\d+\s?" - sections = re.split(section_splitter_patten, value) - num_sections = len(sections) - 1 - return num_sections >= self._num_sections - - -class ParagraphChecker(Instruction): - """Checks the paragraphs.""" - - def build_description(self, *, num_paragraphs=None): - """Build the instruction description. - - Args: - num_paragraphs: An integer specifying the number of paragraphs. - - Returns: - A string representing the instruction description. - """ - self._num_paragraphs = num_paragraphs - if self._num_paragraphs is None or self._num_paragraphs < 0: - self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS) - - self._description_pattern = ( - "There should be {num_paragraphs} paragraphs. " + "Paragraphs are separated with the markdown divider: ***" - ) - - return self._description_pattern.format(num_paragraphs=self._num_paragraphs) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"num_paragraphs": self._num_paragraphs} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_paragraphs"] - - def check_following(self, value): - """Checks the response contains required number of paragraphs. - - Args: - value: A string representing the response. The response may contain - paragraphs that are separated by the markdown divider: `***`. - - Returns: - True if the actual number of paragraphs is the same as required; - otherwise, False. - """ - paragraphs = re.split(r"\s?\*\*\*\s?", value) - num_paragraphs = len(paragraphs) - - for index, paragraph in enumerate(paragraphs): - if not paragraph.strip(): - if index == 0 or index == len(paragraphs) - 1: - num_paragraphs -= 1 - else: - return False - - return num_paragraphs == self._num_paragraphs - - -class PostscriptChecker(Instruction): - """Checks the postscript.""" - - def build_description(self, *, postscript_marker=None): - """Build the instruction description. - - Args: - postscript_marker: A string containing the keyword that marks the start - of the postscript section. - - Returns: - A string representing the instruction description. - """ - self._postscript_marker = postscript_marker.strip() if isinstance(postscript_marker, str) else postscript_marker - if self._postscript_marker is None: - self._postscript_marker = random.choice(_POSTSCRIPT_MARKER) - - self._description_pattern = ( - "At the end of your response, please explicitly add a postscript " + "starting with {postscript}" - ) - - return self._description_pattern.format(postscript=self._postscript_marker) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"postscript_marker": self._postscript_marker} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["postscript_marker"] - - def check_following(self, value): - """Checks if the response follows the postscript format. - - Args: - value: a string representing the response. The response is expected to - contain a postscript section. - - Returns: - True if the response contains a postscript section starting with - the keyword containing in the `instruction_args`; otherwise False. - """ - value = value.lower() - if self._postscript_marker == "P.P.S": - postscript_pattern = r"\s*p\.\s?p\.\s?s.*$" - elif self._postscript_marker == "P.S.": - postscript_pattern = r"\s*p\.\s?s\..*$" - else: - postscript_pattern = r"\s*" + self._postscript_marker.lower() + r".*$" - postscript = re.findall(postscript_pattern, value, flags=re.MULTILINE) - return True if postscript else False - - -class RephraseChecker(Instruction): - """Checks the repharse.""" - - def build_description(self, *, original_message): - """Build the instruction description. - - Args: - original_message: A string representing the original message. The - rephrased response should only change its words/sentences in between - its two asterisks, for example, *change me*. Both original and rephrased - messages should contain the changes in the form of *change me*. - - Returns: - A string representing the instruction description. - """ - if not self.is_change(original_message): - raise ValueError(f"Message {original_message} does not contain changes in the form of *change me*.") - - self._reference_without_change = original_message - self._description = ( - "Rephrasing: Your rephrased response should only" - + "change the words/sentences in between two asterisks" - + "such as *change me*." - ) - return self._description - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"original_message": self._reference_without_change} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["original_message"] - - def check_following(self, value): - r"""Checks if the rephrasing follows the instruction. - - Args: - value: A string representing the response, which is expected to rephras - the string of `instruction_args`. - - Returns: - True if `value` and `instruction_args` only differ by the words/sentences - in between two asterisks such as *change me*; otherwise, False. - """ - - if not self.is_change(value): - raise ValueError(f"value {value} does not contain changes in the form of *change me*.") - - response_without_changes = self.strip_changes(value) - reference_without_changes = self.strip_changes(self._reference_without_change) - - return response_without_changes == reference_without_changes - - def is_change(self, response): - """Check if there is change in the response in the form of *change me*.""" - return re.search(r"\*.*\*", response) - - def strip_changes(self, response): - """Strips off the changes.""" - return re.sub(r"\*.*\*", "", response) - - -class KeywordChecker(Instruction): - """Check the exisitence of certain keywords.""" - - def build_description(self, *, keywords=None): - """Build the instruction description. - - Args: - keywords: A sequence of strings representing the keywords that are - expected in the response. - - Returns: - A string representing the instruction description. - """ - - if not keywords: - self._keywords = generate_keywords(num_keywords=_NUM_KEYWORDS) - else: - self._keywords = keywords - self._keywords = sorted(self._keywords) - - self._description_pattern = "Include keywords {keywords} in the response." - - return self._description_pattern.format(keywords=self._keywords) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"keywords": self._keywords} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["keywords"] - - def check_following(self, value): - """Check if the response contain the expected keywords.""" - for keyword in self._keywords: - if not re.search(keyword, value, flags=re.IGNORECASE): - return False - return True - - -class KeywordFrequencyChecker(Instruction): - """Check the keyword frequency.""" - - def build_description(self, *, keyword=None, frequency=None, relation=None): - """Build the instruction description. - - Args: - keyword: A string representing a keyword that is expected in the response. - frequency: An integer specifying the number of times `keyword` is expected - to appear in the response. - relation: A string in (`less than`, `at least`), defining the relational - operator for comparison. - Two relational comparisons are supported for now: - if 'less than', the actual number of occurrences < frequency; - if 'at least', the actual number of occurrences >= frequency. - - Returns: - A string representing the instruction description. - """ - if not keyword: - self._keyword = generate_keywords(num_keywords=1)[0] - else: - self._keyword = keyword.strip() - - self._frequency = frequency - if self._frequency is None or self._frequency < 0: - self._frequency = random.randint(1, _KEYWORD_FREQUENCY) - - if relation is None: - self._comparison_relation = random.choice(_COMPARISON_RELATION) - elif relation not in _COMPARISON_RELATION: - raise ValueError( - f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {relation} is given." - ) - else: - self._comparison_relation = relation - - self._description_pattern = ( - "In your response, the word {keyword} should appear {relation} " + "{frequency} times." - ) - - return self._description_pattern.format( - keyword=self._keyword, - relation=self._comparison_relation, - frequency=self._frequency, - ) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return { - "keyword": self._keyword, - "frequency": self._frequency, - "relation": self._comparison_relation, - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["keyword", "frequency", "relation"] - - def check_following(self, value): - """Checks if the response contain the keyword with required frequency.""" - actual_occurrences = len(re.findall(self._keyword, value, flags=re.IGNORECASE)) - - if self._comparison_relation == _COMPARISON_RELATION[0]: - return actual_occurrences < self._frequency - elif self._comparison_relation == _COMPARISON_RELATION[1]: - return actual_occurrences >= self._frequency - - -class NumberOfWords(Instruction): - """Checks the number of words.""" - - def build_description(self, *, num_words=None, relation=None): - """Build the instruction description. - - Args: - num_words: An integer specifying the number of words contained in the - response. - relation: A string in (`less than`, `at least`), defining the relational - operator for comparison. - Two relational comparisons are supported for now: - if 'less than', the actual number of words < num_words; - if 'at least', the actual number of words >= num_words. - - Returns: - A string representing the instruction description. - """ - - self._num_words = num_words - if self._num_words is None or self._num_words < 0: - self._num_words = random.randint(_NUM_WORDS_LOWER_LIMIT, _NUM_WORDS_UPPER_LIMIT) - - if relation is None: - self._comparison_relation = random.choice(_COMPARISON_RELATION) - elif relation not in _COMPARISON_RELATION: - raise ValueError( - f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {relation} is given." - ) - else: - self._comparison_relation = relation - - self._description_pattern = "Answer with {relation} {num_words} words." - - return self._description_pattern.format(relation=self._comparison_relation, num_words=self._num_words) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"num_words": self._num_words, "relation": self._comparison_relation} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_words", "relation"] - - def check_following(self, value): - """Checks if the response contains the expected number of words.""" - lang = get_langid(value) - if lang == "th": - num_words = len(word_tokenize_thai(value)) - elif lang in ["zh", "zh-cn", "zh-tw", "ja", "ko"]: - num_words = count_words_cjk(value) - else: - num_words = count_words(value) - - if self._comparison_relation == _COMPARISON_RELATION[0]: - return num_words < self._num_words - elif self._comparison_relation == _COMPARISON_RELATION[1]: - return num_words >= self._num_words - - -class JsonFormat(Instruction): - """Check the Json format.""" - - def build_description(self): - self._description_pattern = ( - "Entire output should be wrapped in JSON format. You can use markdown ticks such as ```." - ) - return self._description_pattern - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - value = ( - value.strip() - .removeprefix("```json") - .removeprefix("```Json") - .removeprefix("```JSON") - .removeprefix("```") - .removesuffix("```") - .strip() - ) - try: - json.loads(value) - except ValueError as _: - return False - return True - - -class ParagraphFirstWordCheck(Instruction): - """Check the paragraph and the first word of the nth paragraph.""" - - def build_description(self, num_paragraphs=None, nth_paragraph=None, first_word=None): - r"""Build the instruction description. - - Args: - num_paragraphs: An integer indicating the number of paragraphs expected - in the response. A paragraph is a subset of the string that is - expected to be separated by '\n\n'. - nth_paragraph: An integer indicating the paragraph number that we look at. - Note that n starts from 1. - first_word: A string that represent the first word of the bth paragraph. - - Returns: - A string representing the instruction description. - """ - self._num_paragraphs = num_paragraphs - if self._num_paragraphs is None or self._num_paragraphs < 0: - self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS) - - self._nth_paragraph = nth_paragraph - if self._nth_paragraph is None or self._nth_paragraph <= 0 or self._nth_paragraph > self._num_paragraphs: - self._nth_paragraph = random.randint(1, self._num_paragraphs + 1) - - self._first_word = first_word - if self._first_word is None: - self._first_word = generate_keywords(num_keywords=1)[0] - self._first_word = self._first_word.lower() - - self._description_pattern = ( - "There should be {num_paragraphs} paragraphs. " - + "Paragraphs and only paragraphs are separated with each other by two " - + "new lines as if it was '\\n\\n' in python. " - + "Paragraph {nth_paragraph} must start with word {first_word}." - ) - - return self._description_pattern.format( - num_paragraphs=self._num_paragraphs, - nth_paragraph=self._nth_paragraph, - first_word=self._first_word, - ) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return { - "num_paragraphs": self._num_paragraphs, - "nth_paragraph": self._nth_paragraph, - "first_word": self._first_word, - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_paragraphs", "nth_paragraph", "first_word"] - - def check_following(self, value): - """Checks for required number of paragraphs and correct first word. - - Args: - value: a string representing the response. The response may contain - paragraphs that are separated by two new lines and the first word of - the nth paragraph will have to match a specified word. - - Returns: - True if the number of paragraphs is the same as required and the first - word of the specified paragraph is the same as required. Otherwise, false. - """ - - paragraphs = re.split(r"\n\n", value) - num_paragraphs = len(paragraphs) - - for paragraph in paragraphs: - if not paragraph.strip(): - num_paragraphs -= 1 - - # check that index doesn't go out of bounds - if self._nth_paragraph <= num_paragraphs: - paragraph = paragraphs[self._nth_paragraph - 1].strip() - if not paragraph: - return False - else: - return False - - first_word = "" - punctuation = {".", ",", "?", "!", "'", '"'} - - # get first word and remove punctuation - word = paragraph.split()[0].strip() - word = word.lstrip("'") - word = word.lstrip('"') - - for letter in word: - if letter in punctuation: - break - first_word += letter.lower() - - return num_paragraphs == self._num_paragraphs and first_word == self._first_word - - -class KeySentenceChecker(Instruction): - """Check the existence of certain key sentences.""" - - def build_description(self, key_sentences=None, num_sentences=None): - """Build the instruction description. - - Args: - key_sentences: A sequences of strings representing the key sentences that - are expected in the response. - num_sentences: The number of key sentences that are expected to be seen in - the response. - - Returns: - A string representing the instruction description. - """ - - if not key_sentences: - self._key_sentences = {["For now, this is fine."]} - else: - self._key_sentences = key_sentences - - if not num_sentences: - self._num_sentences = random.randint(1, len(self._key_sentences)) - else: - self._num_sentences = num_sentences - - self._description_pattern = "Include {num_sentences} of the following sentences {key_sentences}" - - return self._description_pattern.format(num_sentences=self._num_sentences, key_sentences=self._key_sentences) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return { - "num_sentences": self._num_sentences, - "key_sentences": list(self._key_sentences), - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_sentences", "key_sentences"] - - def check_following(self, value): - """Checks if the response contains the expected key sentences.""" - count = 0 - sentences = split_into_sentences(value) - for sentence in self._key_sentences: - if sentence in sentences: - count += 1 - - return count == self._num_sentences - - -class ForbiddenWords(Instruction): - """Checks that specified words are not used in response.""" - - def build_description(self, forbidden_words=None): - """Build the instruction description. - - Args: - forbidden_words: A sequences of strings respresenting words that are not - allowed in the response. - - Returns: - A string representing the instruction description. - """ - - if not forbidden_words: - self._forbidden_words = generate_keywords(num_keywords=_NUM_KEYWORDS) - else: - self._forbidden_words = list(set(forbidden_words)) - self._forbidden_words = sorted(self._forbidden_words) - self._description_pattern = "Do not include keywords {forbidden_words} in the response." - - return self._description_pattern.format(forbidden_words=self._forbidden_words) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"forbidden_words": self._forbidden_words} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["forbidden_words"] - - def check_following(self, value): - """Check if the response does not contain the expected keywords.""" - for word in self._forbidden_words: - if re.search(r"\b" + word + r"\b", value, flags=re.IGNORECASE): - return False - return True - - -class RephraseParagraph(Instruction): - """Checks that the paragraph is rephrased.""" - - def build_description(self, *, original_paragraph, low, high): - """Builds the instruction description. - - Args: - original_paragraph: A string presenting the original paragraph. The - rephrases response should have betweeb low-high words in common. - low: An integer presenting the lower bound of similar words. - high: An integer representing the upper bound of similar words. - - Returns: - A string representing the instruction description. - """ - self._original_paragraph = original_paragraph - self._low = low - self._high = high - - self._description = ( - "Rephrase the following paragraph: " - + "{original_paragraph}\nYour response should have " - + "between {low} and {high} of the same words. " - + "Words are the same if and only if all of the " - + "letters, ignoring cases, are the same. For " - + "example, 'run' is the same as 'Run' but different " - + "to 'ran'." - ) - - return self._description.format(original_paragraph=original_paragraph, low=self._low, high=self._high) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return { - "original_paragraph": self._original_paragraph, - "low": self._low, - "high": self._high, - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["original_paragraph", "low", "high"] - - def check_following(self, value): - val_words = re.findall(r"\w+", value.lower()) - original_words = re.findall(r"\w+", self._original_paragraph.lower()) - similar_words = 0 - - dict_val = collections.Counter(val_words) - dict_original = collections.Counter(original_words) - - for word in dict_original: - similar_words += min(dict_original[word], dict_val[word]) - - return similar_words >= self._low and similar_words <= self._high - - -class TwoResponsesChecker(Instruction): - """Check that two responses were given.""" - - def build_description(self): - """Build the instruction description.""" - self._description_pattern = ( - "Give two different responses. Responses and only responses should" - " be separated by 6 asterisk symbols: ******." - ) - return self._description_pattern - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - """Checks if the response has two different answers. - - Args: - value: A string representing the response. - - Returns: - True if two responses are detected and false otherwise. - """ - valid_responses = list() - responses = value.split("******") - for index, response in enumerate(responses): - if not response.strip(): - if index != 0 and index != len(responses) - 1: - return False - else: - valid_responses.append(response) - return len(valid_responses) == 2 and valid_responses[0].strip() != valid_responses[1].strip() - - -class RepeatPromptThenAnswer(Instruction): - """Checks that Prompt is first repeated then answered.""" - - def build_description(self, *, prompt_to_repeat=None): - """Build the instruction description. - - Args: - prompt_to_repeat: The prompt that is meant to be repeated. - - Returns: - A string representing the instruction description. - """ - if not prompt_to_repeat: - raise ValueError("prompt_to_repeat must be set.") - else: - self._prompt_to_repeat = prompt_to_repeat - self._description_pattern = ( - "First repeat the request word for word without change," - " then give your answer (1. do not say any words or characters" - " before repeating the request; 2. the request you need to repeat" - " does not include this sentence)" - ) - return self._description_pattern - - def get_instruction_args(self): - return {"prompt_to_repeat": self._prompt_to_repeat} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["prompt_to_repeat"] - - def check_following(self, value): - if value.strip().lower().startswith(self._prompt_to_repeat.strip().lower()): - return True - return False - - -class EndChecker(Instruction): - """Checks that the prompt ends with a given phrase.""" - - def build_description(self, *, end_phrase=None): - """Build the instruction description. - - Args: - end_phrase: A string representing the phrase the response should end with. - - Returns: - A string representing the instruction description. - """ - self._end_phrase = end_phrase.strip() if isinstance(end_phrase, str) else end_phrase - if self._end_phrase is None: - self._end_phrase = random.choice(_ENDING_OPTIONS) - self._description_pattern = ( - "Finish your response with this exact phrase {ender}. No other words should follow this phrase." - ) - return self._description_pattern.format(ender=self._end_phrase) - - def get_instruction_args(self): - return {"end_phrase": self._end_phrase} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["end_phrase"] - - def check_following(self, value): - """Checks if the response ends with the expected phrase.""" - value = value.strip().strip('"').lower() - self._end_phrase = self._end_phrase.strip().lower() - return value.endswith(self._end_phrase) - - -class TitleChecker(Instruction): - """Checks the response for a title.""" - - def build_description(self): - """Build the instruction description.""" - self._description_pattern = ( - "Your answer must contain a title, wrapped in double angular brackets, such as <>." - ) - return self._description_pattern - - def get_instruction_args(self): - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - """Checks if the response contains a title.""" - pattern = r"<<[^\n]+>>" - re_pattern = re.compile(pattern) - titles = re.findall(re_pattern, value) - - for title in titles: - if title.lstrip("<").rstrip(">").strip(): - return True - return False - - -class LetterFrequencyChecker(Instruction): - """Checks letter frequency.""" - - def build_description(self, *, letter=None, let_frequency=None, let_relation=None): - """Build the instruction description. - - Args: - letter: A string representing a letter that is expected in the response. - let_frequency: An integer specifying the number of times `keyword` is - expected to appear in the response. - let_relation: A string in (`less than`, `at least`), defining the - relational operator for comparison. Two relational comparisons are - supported for now; if 'less than', the actual number of - occurrences < frequency; if 'at least', the actual number of - occurrences >= frequency. - - Returns: - A string representing the instruction description. - """ - if not letter or len(letter) > 1 or ord(letter.lower()) < 97 or ord(letter.lower()) > 122: - self._letter = random.choice(list(string.ascii_letters)) - else: - self._letter = letter.strip() - self._letter = self._letter.lower() - - self._frequency = let_frequency - if self._frequency is None or self._frequency < 0: - self._frequency = random.randint(1, _LETTER_FREQUENCY) - - if let_relation is None: - self._comparison_relation = random.choice(_COMPARISON_RELATION) - elif let_relation not in _COMPARISON_RELATION: - raise ValueError( - f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {let_relation} is given." - ) - else: - self._comparison_relation = let_relation - - self._description_pattern = ( - "In your response, the letter {letter} should appear {let_relation} {let_frequency} times." - ) - - return self._description_pattern.format( - letter=self._letter, - let_frequency=self._frequency, - let_relation=self._comparison_relation, - ) - - def get_instruction_args(self): - """Returns the keyword args of build description.""" - return { - "letter": self._letter, - "let_frequency": self._frequency, - "let_relation": self._comparison_relation, - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["letter", "let_frequency", "let_relation"] - - def check_following(self, value): - """Checks that the response contains the letter at the right frequency.""" - value = value.lower() - letters = collections.Counter(value) - - if self._comparison_relation == _COMPARISON_RELATION[0]: - return letters[self._letter] < self._frequency - else: - return letters[self._letter] >= self._frequency - - -class CapitalLettersEnglishChecker(Instruction): - """Checks that the response is in english and is in all capital letters.""" - - def build_description(self): - """Build the instruction description.""" - self._description_pattern = "Your entire response should be in English, and in all capital letters." - return self._description_pattern - - def get_instruction_args(self): - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - """Checks that the response is in English and in all capital letters.""" - assert isinstance(value, str) - - try: - return value.isupper() and langdetect.detect(value) == "en" - except langdetect.LangDetectException as e: - # Count as instruction is followed. - logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 - return True - - -class LowercaseLettersEnglishChecker(Instruction): - """Checks that the response is in english and is in all lowercase letters.""" - - def build_description(self): - """Build the instruction description.""" - self._description_pattern = ( - "Your entire response should be in English, and in all lowercase letters. No capital letters are allowed." - ) - return self._description_pattern - - def get_instruction_args(self): - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - """Checks that the response is in English and in all lowercase letters.""" - assert isinstance(value, str) - - try: - return value.islower() and langdetect.detect(value) == "en" - except langdetect.LangDetectException as e: - # Count as instruction is followed. - logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 - return True - - -class CommaChecker(Instruction): - """Checks the response for no commas.""" - - def build_description(self, **kwargs): - """Build the instruction description.""" - self._description_pattern = "In your entire response, refrain from the use of any commas." - return self._description_pattern - - def get_instruction_args(self): - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - """Checks that the response does not contain commas.""" - return not re.search(r"\,", value) - - -class CapitalWordFrequencyChecker(Instruction): - """Checks frequency of words with all capital letters.""" - - def build_description( - self, - capital_frequency=None, - capital_relation=None, - ): - """Build the instruction description. - - Args: - capital_frequency: An integer that represents the number of words that - should be in all capital letters. - capital_relation: A string that is 'at least' or 'at most' that refers to - the frequency. - - Returns: - A string representing the instruction description. - """ - self._frequency = capital_frequency - if self._frequency is None: - self._frequency = random.randint(1, _ALL_CAPITAL_WORD_FREQUENCY) - - self._comparison_relation = capital_relation - if capital_relation is None: - self._comparison_relation = random.choice(_COMPARISON_RELATION) - elif capital_relation not in _COMPARISON_RELATION: - raise ValueError( - "The supported relation for comparison must be in " - f"{_COMPARISON_RELATION}, but {capital_relation} is given." - ) - - self._description_pattern = ( - "In your response, words with all capital letters should appear {relation} {frequency} times." - ) - - return self._description_pattern.format(frequency=self._frequency, relation=self._comparison_relation) - - def get_instruction_args(self): - """Returns the keyword args of build description.""" - return { - "capital_frequency": self._frequency, - "capital_relation": self._comparison_relation, - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["capital_frequency", "capital_relation"] - - def check_following(self, value): - """Checks the frequency of words with all capital letters.""" - # Hyphenated words will count as one word - nltk.download("punkt_tab") - words = nltk.word_tokenize(value) - capital_words = [word for word in words if word.isupper()] - - capital_words = len(capital_words) - - if self._comparison_relation == _COMPARISON_RELATION[0]: - return capital_words < self._frequency - else: - return capital_words >= self._frequency - - -class QuotationChecker(Instruction): - """Checks response is wrapped with double quotation marks.""" - - def build_description(self): - """Build the instruction description.""" - self._description_pattern = "Wrap your entire response with double quotation marks." - return self._description_pattern - - def get_instruction_args(self): - """Returns the keyword args of build description.""" - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - """Checks if the response is wrapped with double quotation marks.""" - quotations_map = { - "ja": "「」", - "ru": "«»", - "th": "“”", - "zh": "“”", - "zh-cn": "“”", - "zh-tw": "“”", - } - value = value.strip() - lang = get_langid(value) - quotes = quotations_map.get(lang, '""') - # TODO: We may wanna revisit this logic in new generations to only check of the response language's quotes. - return len(value) > 1 and value[0] in [quotes[0], '"'] and value[-1] in [quotes[1], '"'] - - -# Define instruction dicts -_KEYWORD = "keywords:" -_LANGUAGE = "language:" -_LENGTH = "length_constraints:" -_CONTENT = "detectable_content:" -_FORMAT = "detectable_format:" -_MULTITURN = "multi-turn:" -_COMBINATION = "combination:" -_STARTEND = "startend:" -_CHANGE_CASES = "change_case:" -_PUNCTUATION = "punctuation:" - -INSTRUCTION_DICT = { - _KEYWORD + "existence": KeywordChecker, - _KEYWORD + "frequency": KeywordFrequencyChecker, - # _KEYWORD + "key_sentences": KeySentenceChecker, - _KEYWORD + "forbidden_words": ForbiddenWords, - _KEYWORD + "letter_frequency": LetterFrequencyChecker, - _LANGUAGE + "response_language": ResponseLanguageChecker, - _LENGTH + "number_sentences": NumberOfSentences, - _LENGTH + "number_paragraphs": ParagraphChecker, - _LENGTH + "number_words": NumberOfWords, - _LENGTH + "nth_paragraph_first_word": ParagraphFirstWordCheck, - _CONTENT + "number_placeholders": PlaceholderChecker, - _CONTENT + "postscript": PostscriptChecker, - _FORMAT + "number_bullet_lists": BulletListChecker, - # _CONTENT + "rephrase_paragraph": RephraseParagraph, - _FORMAT + "constrained_response": ConstrainedResponseChecker, - _FORMAT + "number_highlighted_sections": (HighlightSectionChecker), - _FORMAT + "multiple_sections": SectionChecker, - # _FORMAT + "rephrase": RephraseChecker, - _FORMAT + "json_format": JsonFormat, - _FORMAT + "title": TitleChecker, - # _MULTITURN + "constrained_start": ConstrainedStartChecker, - _COMBINATION + "two_responses": TwoResponsesChecker, - _COMBINATION + "repeat_prompt": RepeatPromptThenAnswer, - _STARTEND + "end_checker": EndChecker, - _CHANGE_CASES + "capital_word_frequency": CapitalWordFrequencyChecker, - _CHANGE_CASES + "english_capital": CapitalLettersEnglishChecker, - _CHANGE_CASES + "english_lowercase": LowercaseLettersEnglishChecker, - _PUNCTUATION + "no_comma": CommaChecker, - _STARTEND + "quotation": QuotationChecker, -} - -INSTRUCTION_LIST = list(INSTRUCTION_DICT.keys()) + [ - _KEYWORD[:-1], - _LANGUAGE[:-1], - _LENGTH[:-1], - _CONTENT[:-1], - _FORMAT[:-1], - _MULTITURN[:-1], - _COMBINATION[:-1], - _STARTEND[:-1], - _CHANGE_CASES[:-1], - _PUNCTUATION[:-1], -] diff --git a/llama_stack/providers/inline/scoring/basic/utils/math_utils.py b/llama_stack/providers/inline/scoring/basic/utils/math_utils.py deleted file mode 100644 index e11fc625b..000000000 --- a/llama_stack/providers/inline/scoring/basic/utils/math_utils.py +++ /dev/null @@ -1,330 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import re -from typing import Sequence - -from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit - -# from minerva -SUBSTITUTIONS = [ - ("an ", ""), - ("a ", ""), - (".$", "$"), - ("\\$", ""), - (r"\ ", ""), - (" ", ""), - ("mbox", "text"), - (",\\text{and}", ","), - ("\\text{and}", ","), - ("\\text{m}", "\\text{}"), -] - -REMOVED_EXPRESSIONS = [ - "square", - "ways", - "integers", - "dollars", - "mph", - "inches", - "ft", - "hours", - "km", - "units", - "\\ldots", - "sue", - "points", - "feet", - "minutes", - "digits", - "cents", - "degrees", - "cm", - "gm", - "pounds", - "meters", - "meals", - "edges", - "students", - "childrentickets", - "multiples", - "\\text{s}", - "\\text{.}", - "\\text{\ns}", - "\\text{}^2", - "\\text{}^3", - "\\text{\n}", - "\\text{}", - r"\mathrm{th}", - r"^\circ", - r"^{\circ}", - r"\;", - r",\!", - "{,}", - '"', - "\\dots", -] - - -def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str: - if isinstance(expression, float): - return expression - new_expression = f"{expression}" - regex = re.compile(r"\\frac{([^}]+)}{([^}]+)}") - for match in re.finditer(regex, expression): - try: - value = float(match.group(1)) / float(match.group(2)) - new_expression = new_expression.replace( - match.group(), - f"{{value:{fmt}}}".format(value=value), - 1, - ) - except Exception: - continue - return new_expression - - -def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str: - try: - with time_limit(seconds=5): - from sympy.parsing.latex import parse_latex - - value = parse_latex(expression).evalf() # type: ignore - return f"{{value:{fmt}}}".format(value=value) - except Exception: - return expression - - -def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str: - for marker in markers: - text = text.split(marker)[0] - return text - - -def extract_result_from_boxed(answer: str) -> str: - box_start = "\\boxed" - # format is `\\boxed $` or `\\boxed{}`, with potential white spaces framing `` - start = answer.rfind(box_start) - if start < 0: - return "" - answer = answer[start + len(box_start) :].strip() - ends_with_curly = answer.startswith("{") - i = 0 - open_braces = 0 - while i < len(answer): - if answer[i] == "{": - open_braces += 1 - elif answer[i] == "}": - open_braces -= 1 - if open_braces == 0: - if ends_with_curly: - answer = answer[: i + 1].strip() - break - elif answer[i] == "$": - answer = answer[:i].strip() - break - i += 1 - else: - return "" - # remove extra curly braces - while True: - if answer.startswith("{") and answer.endswith("}"): - answer = answer[1:-1].strip() - else: - break - return answer - - -# from minerva paper + _normalise_result from xavierm -def normalize_final_answer(final_answer: str, regex_pattern: str, match_first: bool = True) -> str: - """Extract and normalize a final answer to a quantitative reasoning question.""" - match = re.findall(regex_pattern, final_answer) - extraction: str - if len(match) > 0: - if match_first: - extraction = match[0] - else: - extraction = match[-1] - else: - extraction = extract_result_from_boxed(final_answer) - - if len(extraction) == 0: - return final_answer - else: - final_answer = extraction - final_answer = final_answer.split("=")[-1] - for before, after in SUBSTITUTIONS: - final_answer = final_answer.replace(before, after) - for expr in REMOVED_EXPRESSIONS: - final_answer = final_answer.replace(expr, "") - # Extract answer that is in LaTeX math, is bold, - # is surrounded by a box, etc. - final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) - final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) - # Normalize shorthand TeX: - # \fracab -> \frac{a}{b} - # \frac{abc}{bef} -> \frac{abc}{bef} - # \fracabc -> \frac{a}{b}c - # \sqrta -> \sqrt{a} - # \sqrtab -> sqrt{a}b - final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) - final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) - final_answer = final_answer.replace("$", "") - # Normalize 100,000 -> 100000 - if final_answer.replace(",", "").isdigit(): - final_answer = final_answer.replace(",", "") - # If the final answer is a single letter in parentheses, remove the parentheses - # Example: (a) -> a (but not (ab) -> ab) - if re.match(r"\([a-zA-Z]\)", final_answer): - final_answer = final_answer[1] - return _normalise_result(final_answer) - - -def _normalise_result(string: str) -> str: - # linebreaks - string = string.replace("\n", "") - - # remove inverse spaces - string = string.replace("\\!", "") - - # replace \\ with \ - string = string.replace("\\\\", "\\") - - # replace tfrac and dfrac with frac - string = string.replace("cfrac", "frac") - string = string.replace("tfrac", "frac") - string = string.replace("dfrac", "frac") - - # remove \left and \right - string = string.replace("\\left", "") - string = string.replace("\\le", "") - string = string.replace("\\right", "") - - # Remove circ (degrees) - string = string.replace("^{\\circ}", "") - string = string.replace("^\\circ", "") - - # remove dollar signs - string = string.replace("\\$", "") - - # remove units (on the right) - string = _remove_right_units(string) - - # remove percentage - string = string.replace("\\%", "") - string = string.replace(r"\%", "") - - # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string - string = string.replace(" .", " 0.") - string = string.replace("{.", "{0.") - # if empty, return empty string - if len(string) == 0: - return string - if string[0] == ".": - string = "0" + string - - # to consider: get rid of e.g. "k = " or "q = " at beginning - string = string.split("=")[-1] - - # fix sqrt3 --> sqrt{3} - string = _fix_sqrt(string) - - # remove spaces - string = string.replace(" ", "") - - # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} - string = _fix_fracs(string) - - # manually change 0.5 --> \frac{1}{2} - if string == "0.5": - string = "\\frac{1}{2}" - - # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y - string = _fix_a_slash_b(string) - - return string - - -def _remove_right_units(string: str) -> str: - # "\\text{ " only ever occurs (at least in the val set) when describing units - try: - if "\\text{ " in string: - splits = string.split("\\text{ ") - assert len(splits) == 2 - return splits[0] - else: - return string - except AssertionError: - return string - - -def _fix_sqrt(string: str) -> str: - if "\\sqrt" not in string: - return string - splits = string.split("\\sqrt") - new_string = splits[0] - for split in splits[1:]: - if len(split) == 0: - return string - if split[0] != "{": - a = split[0] - new_substr = "\\sqrt{" + a + "}" + split[1:] - else: - new_substr = "\\sqrt" + split - new_string += new_substr - return new_string - - -def _fix_fracs(string: str) -> str: - substrs = string.split("\\frac") - new_str = substrs[0] - if len(substrs) > 1: - substrs = substrs[1:] - for substr in substrs: - new_str += "\\frac" - if len(substr) == 0: - return string - if substr[0] == "{": - new_str += substr - else: - try: - assert len(substr) >= 2 - except AssertionError: - return string - a = substr[0] - b = substr[1] - if b != "{": - if len(substr) > 2: - post_substr = substr[2:] - new_str += "{" + a + "}{" + b + "}" + post_substr - else: - new_str += "{" + a + "}{" + b + "}" - else: - if len(substr) > 2: - post_substr = substr[2:] - new_str += "{" + a + "}" + b + post_substr - else: - new_str += "{" + a + "}" + b - string = new_str - return string - - -def _fix_a_slash_b(string: str) -> str: - if len(string.split("/")) != 2: - return string - a = string.split("/")[0] - b = string.split("/")[1] - try: - ia = int(a) - ib = int(b) - assert string == "{}/{}".format(ia, ib) - new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}" - return new_string - except (ValueError, AssertionError): - return string diff --git a/llama_stack/providers/inline/scoring/braintrust/__init__.py b/llama_stack/providers/inline/scoring/braintrust/__init__.py deleted file mode 100644 index f1b0112d9..000000000 --- a/llama_stack/providers/inline/scoring/braintrust/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from typing import Any, Dict - -from pydantic import BaseModel - -from llama_stack.distribution.datatypes import Api - -from .config import BraintrustScoringConfig - - -class BraintrustProviderDataValidator(BaseModel): - openai_api_key: str - - -async def get_provider_impl( - config: BraintrustScoringConfig, - deps: Dict[Api, Any], -): - from .braintrust import BraintrustScoringImpl - - impl = BraintrustScoringImpl(config, deps[Api.datasetio], deps[Api.datasets]) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py deleted file mode 100644 index 3fae83340..000000000 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -import os -from typing import Any, Dict, List, Optional - -from autoevals.llm import Factuality -from autoevals.ragas import ( - AnswerCorrectness, - AnswerRelevancy, - AnswerSimilarity, - ContextEntityRecall, - ContextPrecision, - ContextRecall, - ContextRelevancy, - Faithfulness, -) -from pydantic import BaseModel - -from llama_stack.apis.datasetio import DatasetIO -from llama_stack.apis.datasets import Datasets -from llama_stack.apis.scoring import ( - ScoreBatchResponse, - ScoreResponse, - Scoring, - ScoringResult, - ScoringResultRow, -) -from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams -from llama_stack.distribution.datatypes import Api -from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate -from llama_stack.providers.utils.common.data_schema_validator import ( - get_valid_schemas, - validate_dataset_schema, - validate_row_schema, -) -from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics - -from .config import BraintrustScoringConfig -from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def -from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def -from .scoring_fn.fn_defs.answer_similarity import answer_similarity_fn_def -from .scoring_fn.fn_defs.context_entity_recall import context_entity_recall_fn_def -from .scoring_fn.fn_defs.context_precision import context_precision_fn_def -from .scoring_fn.fn_defs.context_recall import context_recall_fn_def -from .scoring_fn.fn_defs.context_relevancy import context_relevancy_fn_def -from .scoring_fn.fn_defs.factuality import factuality_fn_def -from .scoring_fn.fn_defs.faithfulness import faithfulness_fn_def - - -class BraintrustScoringFnEntry(BaseModel): - identifier: str - evaluator: Any - fn_def: ScoringFn - - -SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [ - BraintrustScoringFnEntry( - identifier="braintrust::factuality", - evaluator=Factuality(), - fn_def=factuality_fn_def, - ), - BraintrustScoringFnEntry( - identifier="braintrust::answer-correctness", - evaluator=AnswerCorrectness(), - fn_def=answer_correctness_fn_def, - ), - BraintrustScoringFnEntry( - identifier="braintrust::answer-relevancy", - evaluator=AnswerRelevancy(), - fn_def=answer_relevancy_fn_def, - ), - BraintrustScoringFnEntry( - identifier="braintrust::answer-similarity", - evaluator=AnswerSimilarity(), - fn_def=answer_similarity_fn_def, - ), - BraintrustScoringFnEntry( - identifier="braintrust::faithfulness", - evaluator=Faithfulness(), - fn_def=faithfulness_fn_def, - ), - BraintrustScoringFnEntry( - identifier="braintrust::context-entity-recall", - evaluator=ContextEntityRecall(), - fn_def=context_entity_recall_fn_def, - ), - BraintrustScoringFnEntry( - identifier="braintrust::context-precision", - evaluator=ContextPrecision(), - fn_def=context_precision_fn_def, - ), - BraintrustScoringFnEntry( - identifier="braintrust::context-recall", - evaluator=ContextRecall(), - fn_def=context_recall_fn_def, - ), - BraintrustScoringFnEntry( - identifier="braintrust::context-relevancy", - evaluator=ContextRelevancy(), - fn_def=context_relevancy_fn_def, - ), -] - - -class BraintrustScoringImpl( - Scoring, - ScoringFunctionsProtocolPrivate, - NeedsRequestProviderData, -): - def __init__( - self, - config: BraintrustScoringConfig, - datasetio_api: DatasetIO, - datasets_api: Datasets, - ) -> None: - self.config = config - self.datasetio_api = datasetio_api - self.datasets_api = datasets_api - - self.braintrust_evaluators = { - entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY - } - self.supported_fn_defs_registry = { - entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY - } - - async def initialize(self) -> None: ... - - async def shutdown(self) -> None: ... - - async def list_scoring_functions(self) -> List[ScoringFn]: - scoring_fn_defs_list = list(self.supported_fn_defs_registry.values()) - for f in scoring_fn_defs_list: - assert f.identifier.startswith("braintrust"), ( - "All braintrust scoring fn must have identifier prefixed with 'braintrust'! " - ) - - return scoring_fn_defs_list - - async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: - raise NotImplementedError("Registering scoring function not allowed for braintrust provider") - - async def set_api_key(self) -> None: - # api key is in the request headers - if not self.config.openai_api_key: - provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.openai_api_key: - raise ValueError( - 'Pass OpenAI API Key in the header X-LlamaStack-Provider-Data as { "openai_api_key": }' - ) - self.config.openai_api_key = provider_data.openai_api_key - - os.environ["OPENAI_API_KEY"] = self.config.openai_api_key - - async def score_batch( - self, - dataset_id: str, - scoring_functions: Dict[str, Optional[ScoringFnParams]], - save_results_dataset: bool = False, - ) -> ScoreBatchResponse: - await self.set_api_key() - - dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) - - all_rows = await self.datasetio_api.iterrows( - dataset_id=dataset_id, - limit=-1, - ) - res = await self.score(input_rows=all_rows.data, scoring_functions=scoring_functions) - if save_results_dataset: - # TODO: persist and register dataset on to server for reading - # self.datasets_api.register_dataset() - raise NotImplementedError("Save results dataset not implemented yet") - - return ScoreBatchResponse( - results=res.results, - ) - - async def score_row( - self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None - ) -> ScoringResultRow: - validate_row_schema(input_row, get_valid_schemas(Api.scoring.value)) - await self.set_api_key() - assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None" - expected_answer = input_row["expected_answer"] - generated_answer = input_row["generated_answer"] - input_query = input_row["input_query"] - evaluator = self.braintrust_evaluators[scoring_fn_identifier] - - result = evaluator( - generated_answer, - expected_answer, - input=input_query, - context=input_row["context"] if "context" in input_row else None, - ) - score = result.score - return {"score": score, "metadata": result.metadata} - - async def score( - self, - input_rows: List[Dict[str, Any]], - scoring_functions: Dict[str, Optional[ScoringFnParams]], - ) -> ScoreResponse: - await self.set_api_key() - res = {} - for scoring_fn_id in scoring_functions: - if scoring_fn_id not in self.supported_fn_defs_registry: - raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") - - score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows] - aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions - - # override scoring_fn params if provided - if scoring_functions[scoring_fn_id] is not None: - override_params = scoring_functions[scoring_fn_id] - if override_params.aggregation_functions: - aggregation_functions = override_params.aggregation_functions - - agg_results = aggregate_metrics(score_results, aggregation_functions) - res[scoring_fn_id] = ScoringResult( - score_rows=score_results, - aggregated_results=agg_results, - ) - - return ScoreResponse( - results=res, - ) diff --git a/llama_stack/providers/inline/scoring/braintrust/config.py b/llama_stack/providers/inline/scoring/braintrust/config.py deleted file mode 100644 index d4e0d9bcd..000000000 --- a/llama_stack/providers/inline/scoring/braintrust/config.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from typing import Any, Dict, Optional - -from pydantic import BaseModel, Field - - -class BraintrustScoringConfig(BaseModel): - openai_api_key: Optional[str] = Field( - default=None, - description="The OpenAI API Key", - ) - - @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: - return { - "openai_api_key": "${env.OPENAI_API_KEY:}", - } diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/__init__.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/__init__.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py deleted file mode 100644 index 4fe07f822..000000000 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -answer_correctness_fn_def = ScoringFn( - identifier="braintrust::answer-correctness", - description=( - "Scores the correctness of the answer based on the ground truth. " - "Uses Braintrust LLM-based scorer from autoevals library." - ), - provider_id="braintrust", - provider_resource_id="answer-correctness", - return_type=NumberType(), - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]), -) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py deleted file mode 100644 index a1995cc4e..000000000 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -answer_relevancy_fn_def = ScoringFn( - identifier="braintrust::answer-relevancy", - description=( - "Test output relevancy against the input query using Braintrust LLM scorer. " - "See: github.com/braintrustdata/autoevals" - ), - provider_id="braintrust", - provider_resource_id="answer-relevancy", - return_type=NumberType(), - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]), -) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_similarity.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_similarity.py deleted file mode 100644 index e8fe15259..000000000 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_similarity.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -answer_similarity_fn_def = ScoringFn( - identifier="braintrust::answer-similarity", - description=( - "Test output similarity against expected value using Braintrust LLM scorer. " - "See: github.com/braintrustdata/autoevals" - ), - provider_id="braintrust", - provider_resource_id="answer-similarity", - return_type=NumberType(), - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]), -) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py deleted file mode 100644 index d9b129a8b..000000000 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -context_entity_recall_fn_def = ScoringFn( - identifier="braintrust::context-entity-recall", - description=( - "Evaluates how well the context captures the named entities present in the " - "reference answer. See: github.com/braintrustdata/autoevals" - ), - provider_id="braintrust", - provider_resource_id="context-entity-recall", - return_type=NumberType(), - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]), -) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_precision.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_precision.py deleted file mode 100644 index c1d7e855b..000000000 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_precision.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -context_precision_fn_def = ScoringFn( - identifier="braintrust::context-precision", - description=( - "Measures how much of the provided context is actually relevant to answering the " - "question. See: github.com/braintrustdata/autoevals" - ), - provider_id="braintrust", - provider_resource_id="context-precision", - return_type=NumberType(), - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]), -) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_recall.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_recall.py deleted file mode 100644 index 01ddd0dd0..000000000 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_recall.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -context_recall_fn_def = ScoringFn( - identifier="braintrust::context-recall", - description=( - "Evaluates how well the context covers the information needed to answer the " - "question. See: github.com/braintrustdata/autoevals" - ), - provider_id="braintrust", - provider_resource_id="context-recall", - return_type=NumberType(), - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]), -) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_relevancy.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_relevancy.py deleted file mode 100644 index 55d89344a..000000000 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_relevancy.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -context_relevancy_fn_def = ScoringFn( - identifier="braintrust::context-relevancy", - description=( - "Assesses how relevant the provided context is to the given question. See: github.com/braintrustdata/autoevals" - ), - provider_id="braintrust", - provider_resource_id="context-relevancy", - return_type=NumberType(), - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]), -) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py deleted file mode 100644 index c621ecf7f..000000000 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -factuality_fn_def = ScoringFn( - identifier="braintrust::factuality", - description=( - "Test output factuality against expected value using Braintrust LLM scorer. " - "See: github.com/braintrustdata/autoevals" - ), - provider_id="braintrust", - provider_resource_id="factuality", - return_type=NumberType(), - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]), -) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/faithfulness.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/faithfulness.py deleted file mode 100644 index 2e85c0c7c..000000000 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/faithfulness.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -faithfulness_fn_def = ScoringFn( - identifier="braintrust::faithfulness", - description=( - "Test output faithfulness to the input query using Braintrust LLM scorer. " - "See: github.com/braintrustdata/autoevals" - ), - provider_id="braintrust", - provider_resource_id="faithfulness", - return_type=NumberType(), - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]), -) diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py b/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py deleted file mode 100644 index 4a83bfe13..000000000 --- a/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from typing import Any, Dict - -from llama_stack.distribution.datatypes import Api - -from .config import LlmAsJudgeScoringConfig - - -async def get_provider_impl( - config: LlmAsJudgeScoringConfig, - deps: Dict[Api, Any], -): - from .scoring import LlmAsJudgeScoringImpl - - impl = LlmAsJudgeScoringImpl(config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference]) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/config.py b/llama_stack/providers/inline/scoring/llm_as_judge/config.py deleted file mode 100644 index ff63fc5e7..000000000 --- a/llama_stack/providers/inline/scoring/llm_as_judge/config.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from typing import Any, Dict - -from pydantic import BaseModel - - -class LlmAsJudgeScoringConfig(BaseModel): - @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: - return {} diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py deleted file mode 100644 index 7f004fbb6..000000000 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from typing import Any, Dict, List, Optional - -from llama_stack.apis.datasetio import DatasetIO -from llama_stack.apis.datasets import Datasets -from llama_stack.apis.inference.inference import Inference -from llama_stack.apis.scoring import ( - ScoreBatchResponse, - ScoreResponse, - Scoring, - ScoringResult, -) -from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams -from llama_stack.distribution.datatypes import Api -from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate -from llama_stack.providers.utils.common.data_schema_validator import ( - get_valid_schemas, - validate_dataset_schema, -) - -from .config import LlmAsJudgeScoringConfig -from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn - -LLM_JUDGE_FN = LlmAsJudgeScoringFn - - -class LlmAsJudgeScoringImpl( - Scoring, - ScoringFunctionsProtocolPrivate, -): - def __init__( - self, - config: LlmAsJudgeScoringConfig, - datasetio_api: DatasetIO, - datasets_api: Datasets, - inference_api: Inference, - ) -> None: - self.config = config - self.datasetio_api = datasetio_api - self.datasets_api = datasets_api - self.inference_api = inference_api - - async def initialize(self) -> None: - impl = LLM_JUDGE_FN(inference_api=self.inference_api) - self.llm_as_judge_fn = impl - - async def shutdown(self) -> None: ... - - async def list_scoring_functions(self) -> List[ScoringFn]: - scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs() - - for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs(): - assert f.identifier.startswith("llm-as-judge"), ( - "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! " - ) - - return scoring_fn_defs_list - - async def register_scoring_function(self, function_def: ScoringFn) -> None: - self.llm_as_judge_fn.register_scoring_fn_def(function_def) - - async def score_batch( - self, - dataset_id: str, - scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, - save_results_dataset: bool = False, - ) -> ScoreBatchResponse: - dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) - validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)) - - all_rows = await self.datasetio_api.iterrows( - dataset_id=dataset_id, - limit=-1, - ) - res = await self.score( - input_rows=all_rows.data, - scoring_functions=scoring_functions, - ) - if save_results_dataset: - # TODO: persist and register dataset on to server for reading - # self.datasets_api.register_dataset() - raise NotImplementedError("Save results dataset not implemented yet") - - return ScoreBatchResponse( - results=res.results, - ) - - async def score( - self, - input_rows: List[Dict[str, Any]], - scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, - ) -> ScoreResponse: - res = {} - for scoring_fn_id in scoring_functions.keys(): - scoring_fn = self.llm_as_judge_fn - scoring_fn_params = scoring_functions.get(scoring_fn_id, None) - score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params) - agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params) - res[scoring_fn_id] = ScoringResult( - score_rows=score_results, - aggregated_results=agg_results, - ) - - return ScoreResponse( - results=res, - ) diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/__init__.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/__init__.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_simpleqa.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_simpleqa.py deleted file mode 100644 index 074f1ff46..000000000 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_simpleqa.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - LLMAsJudgeScoringFnParams, - ScoringFn, -) - -GRADER_TEMPLATE = """ -Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"]. -First, I will give examples of each grade, and then you will grade a new example. -The following are examples of CORRECT predicted answers. -``` -Question: What are the names of Barack Obama's children? -Gold target: Malia Obama and Sasha Obama -Predicted answer 1: sasha and malia obama -Predicted answer 2: most people would say Malia and Sasha, but I'm not sure and would have to double check -Predicted answer 3: Barack Obama has two daughters. Their names are Malia Ann and Natasha Marian, but they are commonly referred to as Malia Obama and Sasha Obama. Malia was born on July 4, 1998, and Sasha was born on June 10, 2001. -``` -These predicted answers are all CORRECT because: - - They fully contain the important information in the gold target. - - They do not contain any information that contradicts the gold target. - - Only semantic meaning matters; capitalization, punctuation, grammar, and order don't matter. - - Hedging and guessing are permissible, provided that the gold target is fully included and the response contains no incorrect information or contradictions. -The following are examples of INCORRECT predicted answers. -``` -Question: What are the names of Barack Obama's children? -Gold target: Malia and Sasha -Predicted answer 1: Malia. -Predicted answer 2: Malia, Sasha, and Susan. -Predicted answer 3: Barack Obama does not have any children. -Predicted answer 4: I think it's either Malia and Sasha. Or it could be Malia and Jackie. Or it could be Joey and Malia. -Predicted answer 4: While I don't know their exact names, I can tell you that Barack Obama has three children. -Predicted answer 5: It's possible you may mean Betsy and Olivia. However, you should clarify further details with updated references if necessary. Is that the correct answer? -Predicted answer 6: It may be the case that Obama's child is named James. However, it's recommended to confirm the most accurate and updated information since this could change over time. This model may not always reflect the most current information. -``` -These predicted answers are all INCORRECT because: - - A factual statement in the answer contradicts the gold target. Incorrect statements that have some hedging (e.g., "it is possible that", "although i'm not sure, i think") are also considered incorrect. -The following are examples of NOT_ATTEMPTED predicted answers. -``` -Question: What are the names of Barack Obama's children? -Gold target: Malia and Sasha -Predicted answer 1: I don't know. -Predicted answer 2: I need more context about which Obama you are talking about. -Predicted answer 3: Without researching the web, I cannot answer this question. However, I can tell you that Barack Obama has two children. -Predicted answer 4: Barack Obama has two children. I know that one of them is Malia, but I'm not sure about the other one. -``` -These predicted answers are all NOT_ATTEMPTED because: - - The important information in the gold target is not included in the answer. - - No statements in the answer contradict the gold target. -Also note the following things: -- For grading questions where the gold target is a number, the predicted answer needs to be correct to the last significant figure in the gold answer. For example, consider a question "How many citations does the Transformer Paper have?" with gold target "120k". - - Predicted answers "120k", "124k", and 115k" are all CORRECT. - - Predicted answers "100k" and "113k" are INCORRECT. - - Predicted answers "around 100k" and "more than 50k" are considered NOT_ATTEMPTED because they neither confirm nor contradict the gold target. -- The gold target may contain more information than the question. In such cases, the predicted answer only needs to contain the information that is in the question. - - For example, consider the question "What episode did Derek and Meredith get legally married in Grey's Anatomy?" with gold target "Season 7, Episode 20: White Wedding". Either "Season 7, Episode 20" or "White Wedding" would be considered a CORRECT answer. -- Do not punish predicted answers if they omit information that would be clearly inferred from the question. - - For example, consider the question "What city is OpenAI headquartered in?" and the gold target "San Francisco, California". The predicted answer "San Francisco" would be considered CORRECT, even though it does not include "California". - - Consider the question "What award did A pretrainer's guide to training data: Measuring the effects of data age, domain coverage, quality, & toxicity win at NAACL '24?", the gold target is "Outstanding Paper Award". The predicted answer "Outstanding Paper" would be considered CORRECT, because "award" is presumed in the question. - - For the question "What is the height of Jason Wei in meters?", the gold target is "1.73 m". The predicted answer "1.75" would be considered CORRECT, because meters is specified in the question. - - For the question "What is the name of Barack Obama's wife?", the gold target is "Michelle Obama". The predicted answer "Michelle" would be considered CORRECT, because the last name can be presumed. -- Do not punish for typos in people's name if it's clearly the same name. - - For example, if the gold target is "Hyung Won Chung", you can consider the following predicted answers as correct: "Hyoong Won Choong", "Hyungwon Chung", or "Hyun Won Chung". -Here is a new example. Simply reply with either CORRECT, INCORRECT, NOT ATTEMPTED. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer. -``` -Question: {input_query} -Gold target: {expected_answer} -Predicted answer: {generated_answer} -``` -Grade the predicted answer of this new question as one of: -A: CORRECT -B: INCORRECT -C: NOT_ATTEMPTED -Just return the letters "A", "B", or "C", with no text around it. -""".strip() - - -llm_as_judge_405b_simpleqa = ScoringFn( - identifier="llm-as-judge::405b-simpleqa", - description="Llm As Judge Scoring Function for SimpleQA Benchmark (https://github.com/openai/simple-evals/blob/main/simpleqa_eval.py)", - return_type=NumberType(), - provider_id="llm-as-judge", - provider_resource_id="llm-as-judge-405b-simpleqa", - params=LLMAsJudgeScoringFnParams( - judge_model="meta-llama/Llama-3.1-405B-Instruct", - prompt_template=GRADER_TEMPLATE, - judge_score_regexes=[r"(A|B|C)"], - aggregation_functions=[AggregationFunctionType.categorical_count.value], - ), -) diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py deleted file mode 100644 index 205e0bbf3..000000000 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn - -llm_as_judge_base = ScoringFn( - identifier="llm-as-judge::base", - description="Llm As Judge Scoring Function", - return_type=NumberType(), - provider_id="llm-as-judge", - provider_resource_id="llm-as-judge-base", - params=LLMAsJudgeScoringFnParams( - judge_model="meta-llama/Llama-3.1-405B-Instruct", - prompt_template="Enter custom LLM as Judge Prompt Template", - ), -) diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py deleted file mode 100644 index f4e8ab0aa..000000000 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -import re -from typing import Any, Dict, Optional - -from llama_stack.apis.inference.inference import Inference, UserMessage -from llama_stack.apis.scoring import ScoringResultRow -from llama_stack.apis.scoring_functions import ScoringFnParams -from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn - -from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa -from .fn_defs.llm_as_judge_base import llm_as_judge_base - - -class LlmAsJudgeScoringFn(RegisteredBaseScoringFn): - """ - A scoring_fn that assigns - """ - - def __init__(self, inference_api: Inference, *arg, **kwargs) -> None: - super().__init__(*arg, **kwargs) - self.inference_api = inference_api - self.supported_fn_defs_registry = { - llm_as_judge_base.identifier: llm_as_judge_base, - llm_as_judge_405b_simpleqa.identifier: llm_as_judge_405b_simpleqa, - } - - async def score_row( - self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = None, - scoring_params: Optional[ScoringFnParams] = None, - ) -> ScoringResultRow: - assert scoring_fn_identifier is not None, "Scoring function identifier not found." - fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] - - # override params if scoring_params is provided - if scoring_params is not None: - fn_def.params = scoring_params - - assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}." - assert fn_def.params.prompt_template is not None, "LLM Judge prompt_template not found." - assert fn_def.params.judge_score_regexes is not None, "LLM Judge judge_score_regexes not found." - - input_query = input_row["input_query"] - expected_answer = input_row["expected_answer"] - generated_answer = input_row["generated_answer"] - - judge_input_msg = fn_def.params.prompt_template.format( - input_query=input_query, - expected_answer=expected_answer, - generated_answer=generated_answer, - ) - - judge_response = await self.inference_api.chat_completion( - model_id=fn_def.params.judge_model, - messages=[ - UserMessage( - content=judge_input_msg, - ), - ], - ) - content = judge_response.completion_message.content - rating_regexes = fn_def.params.judge_score_regexes - - judge_rating = None - for regex in rating_regexes: - match = re.search(regex, content) - if match: - judge_rating = match.group(1) - break - - return { - "score": judge_rating, - "judge_feedback": content, - } diff --git a/llama_stack/templates/open-benchmark/__init__.py b/llama_stack/templates/open-benchmark/__init__.py deleted file mode 100644 index 14d0a28f5..000000000 --- a/llama_stack/templates/open-benchmark/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .open_benchmark import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/open-benchmark/build.yaml b/llama_stack/templates/open-benchmark/build.yaml deleted file mode 100644 index b39a17820..000000000 --- a/llama_stack/templates/open-benchmark/build.yaml +++ /dev/null @@ -1,30 +0,0 @@ -version: '2' -distribution_spec: - description: Distribution for running open benchmarks - providers: - inference: - - remote::openai - - remote::anthropic - - remote::gemini - - remote::groq - - remote::together - vector_io: - - inline::sqlite-vec - - remote::chromadb - - remote::pgvector - safety: - - inline::llama-guard - agents: - - inline::meta-reference - telemetry: - - inline::meta-reference - datasetio: - - remote::huggingface - - inline::localfs - tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::code-interpreter - - inline::rag-runtime - - remote::model-context-protocol -image_type: conda diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py deleted file mode 100644 index cfa9135cf..000000000 --- a/llama_stack/templates/open-benchmark/open_benchmark.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Dict, List, Tuple - -from llama_stack.apis.datasets import DatasetPurpose, URIDataSource -from llama_stack.apis.models.models import ModelType -from llama_stack.distribution.datatypes import ( - DatasetInput, - ModelInput, - Provider, - ShieldInput, - ToolGroupInput, -) -from llama_stack.providers.inline.vector_io.sqlite_vec.config import ( - SQLiteVectorIOConfig, -) -from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig -from llama_stack.providers.remote.inference.gemini.config import GeminiConfig -from llama_stack.providers.remote.inference.groq.config import GroqConfig -from llama_stack.providers.remote.inference.openai.config import OpenAIConfig -from llama_stack.providers.remote.inference.together.config import TogetherImplConfig -from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig -from llama_stack.providers.remote.vector_io.pgvector.config import ( - PGVectorVectorIOConfig, -) -from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry -from llama_stack.templates.template import ( - DistributionTemplate, - RunConfigSettings, - get_model_registry, -) - - -def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]: - # in this template, we allow each API key to be optional - providers = [ - ( - "openai", - [ - ProviderModelEntry( - provider_model_id="openai/gpt-4o", - model_type=ModelType.llm, - ) - ], - OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:}"), - ), - ( - "anthropic", - [ - ProviderModelEntry( - provider_model_id="anthropic/claude-3-5-sonnet-latest", - model_type=ModelType.llm, - ) - ], - AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:}"), - ), - ( - "gemini", - [ - ProviderModelEntry( - provider_model_id="gemini/gemini-1.5-flash", - model_type=ModelType.llm, - ) - ], - GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"), - ), - ( - "groq", - [], - GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}"), - ), - ( - "together", - [], - TogetherImplConfig.sample_run_config(api_key="${env.TOGETHER_API_KEY:}"), - ), - ] - inference_providers = [] - available_models = {} - for provider_id, model_entries, config in providers: - inference_providers.append( - Provider( - provider_id=provider_id, - provider_type=f"remote::{provider_id}", - config=config, - ) - ) - available_models[provider_id] = model_entries - return inference_providers, available_models - - -def get_distribution_template() -> DistributionTemplate: - inference_providers, available_models = get_inference_providers() - providers = { - "inference": [p.provider_type for p in inference_providers], - "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "datasetio": ["remote::huggingface", "inline::localfs"], - "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::code-interpreter", - "inline::rag-runtime", - "remote::model-context-protocol", - ], - } - name = "open-benchmark" - - vector_io_providers = [ - Provider( - provider_id="sqlite-vec", - provider_type="inline::sqlite-vec", - config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), - ), - Provider( - provider_id="${env.ENABLE_CHROMADB+chromadb}", - provider_type="remote::chromadb", - config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"), - ), - Provider( - provider_id="${env.ENABLE_PGVECTOR+pgvector}", - provider_type="remote::pgvector", - config=PGVectorVectorIOConfig.sample_run_config( - db="${env.PGVECTOR_DB:}", - user="${env.PGVECTOR_USER:}", - password="${env.PGVECTOR_PASSWORD:}", - ), - ), - ] - - default_tool_groups = [ - ToolGroupInput( - toolgroup_id="builtin::websearch", - provider_id="tavily-search", - ), - ToolGroupInput( - toolgroup_id="builtin::rag", - provider_id="rag-runtime", - ), - ToolGroupInput( - toolgroup_id="builtin::code_interpreter", - provider_id="code-interpreter", - ), - ] - - default_models = get_model_registry(available_models) + [ - ModelInput( - model_id="meta-llama/Llama-3.3-70B-Instruct", - provider_id="groq", - provider_model_id="groq/llama-3.3-70b-versatile", - model_type=ModelType.llm, - ), - ModelInput( - model_id="meta-llama/Llama-3.1-405B-Instruct", - provider_id="together", - provider_model_id="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", - model_type=ModelType.llm, - ), - ] - - default_datasets = [ - DatasetInput( - dataset_id="simpleqa", - purpose=DatasetPurpose.eval_messages_answer, - source=URIDataSource( - uri="huggingface://datasets/llamastack/simpleqa?split=train", - ), - ), - DatasetInput( - dataset_id="mmlu_cot", - purpose=DatasetPurpose.eval_messages_answer, - source=URIDataSource( - uri="huggingface://datasets/llamastack/mmlu_cot?split=test&name=all", - ), - ), - DatasetInput( - dataset_id="gpqa_cot", - purpose=DatasetPurpose.eval_messages_answer, - source=URIDataSource( - uri="huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main", - ), - ), - DatasetInput( - dataset_id="math_500", - purpose=DatasetPurpose.eval_messages_answer, - source=URIDataSource( - uri="huggingface://datasets/llamastack/math_500?split=test", - ), - ), - DatasetInput( - dataset_id="bfcl", - purpose=DatasetPurpose.eval_messages_answer, - source=URIDataSource( - uri="huggingface://datasets/llamastack/bfcl_v3?split=train", - ), - ), - DatasetInput( - dataset_id="ifeval", - purpose=DatasetPurpose.eval_messages_answer, - source=URIDataSource( - uri="huggingface://datasets/llamastack/IfEval?split=train", - ), - ), - DatasetInput( - dataset_id="docvqa", - purpose=DatasetPurpose.eval_messages_answer, - source=URIDataSource( - uri="huggingface://datasets/llamastack/docvqa?split=val", - ), - ), - ] - - # TODO(xiyan): fix this back as registerable resources - # default_benchmarks = [ - # BenchmarkInput( - # benchmark_id="meta-reference-simpleqa", - # dataset_id="simpleqa", - # grader_ids=["llm-as-judge::405b-simpleqa"], - # ), - # BenchmarkInput( - # benchmark_id="meta-reference-mmlu-cot", - # dataset_id="mmlu_cot", - # grader_ids=["basic::regex_parser_multiple_choice_answer"], - # ), - # BenchmarkInput( - # benchmark_id="meta-reference-gpqa-cot", - # dataset_id="gpqa_cot", - # grader_ids=["basic::regex_parser_multiple_choice_answer"], - # ), - # BenchmarkInput( - # benchmark_id="meta-reference-math-500", - # dataset_id="math_500", - # grader_ids=["basic::regex_parser_math_response"], - # ), - # BenchmarkInput( - # benchmark_id="meta-reference-bfcl", - # dataset_id="bfcl", - # grader_ids=["basic::bfcl"], - # ), - # BenchmarkInput( - # benchmark_id="meta-reference-ifeval", - # dataset_id="ifeval", - # grader_ids=["basic::ifeval"], - # ), - # BenchmarkInput( - # benchmark_id="meta-reference-docvqa", - # dataset_id="docvqa", - # grader_ids=["basic::docvqa"], - # ), - # ] - - return DistributionTemplate( - name=name, - distro_type="self_hosted", - description="Distribution for running open benchmarks", - container_image=None, - template_path=None, - providers=providers, - available_models_by_provider=available_models, - run_configs={ - "run.yaml": RunConfigSettings( - provider_overrides={ - "inference": inference_providers, - "vector_io": vector_io_providers, - }, - default_models=default_models, - default_tool_groups=default_tool_groups, - default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], - default_datasets=default_datasets, - ), - }, - run_config_env_vars={ - "LLAMA_STACK_PORT": ( - "8321", - "Port for the Llama Stack distribution server", - ), - "TOGETHER_API_KEY": ( - "", - "Together API Key", - ), - "OPENAI_API_KEY": ( - "", - "OpenAI API Key", - ), - "GEMINI_API_KEY": ( - "", - "Gemini API Key", - ), - "ANTHROPIC_API_KEY": ( - "", - "Anthropic API Key", - ), - "GROQ_API_KEY": ( - "", - "Groq API Key", - ), - }, - ) diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml deleted file mode 100644 index d9ca11a84..000000000 --- a/llama_stack/templates/open-benchmark/run.yaml +++ /dev/null @@ -1,190 +0,0 @@ -version: '2' -image_name: open-benchmark -apis: -- agents -- datasetio -- inference -- safety -- telemetry -- tool_runtime -- vector_io -providers: - inference: - - provider_id: openai - provider_type: remote::openai - config: - api_key: ${env.OPENAI_API_KEY:} - - provider_id: anthropic - provider_type: remote::anthropic - config: - api_key: ${env.ANTHROPIC_API_KEY:} - - provider_id: gemini - provider_type: remote::gemini - config: - api_key: ${env.GEMINI_API_KEY:} - - provider_id: groq - provider_type: remote::groq - config: - url: https://api.groq.com - api_key: ${env.GROQ_API_KEY:} - - provider_id: together - provider_type: remote::together - config: - url: https://api.together.xyz/v1 - api_key: ${env.TOGETHER_API_KEY:} - vector_io: - - provider_id: sqlite-vec - provider_type: inline::sqlite-vec - config: - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/sqlite_vec.db - - provider_id: ${env.ENABLE_CHROMADB+chromadb} - provider_type: remote::chromadb - config: - url: ${env.CHROMADB_URL:} - - provider_id: ${env.ENABLE_PGVECTOR+pgvector} - provider_type: remote::pgvector - config: - host: ${env.PGVECTOR_HOST:localhost} - port: ${env.PGVECTOR_PORT:5432} - db: ${env.PGVECTOR_DB:} - user: ${env.PGVECTOR_USER:} - password: ${env.PGVECTOR_PASSWORD:} - safety: - - provider_id: llama-guard - provider_type: inline::llama-guard - config: - excluded_categories: [] - agents: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - persistence_store: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/agents_store.db - telemetry: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - sinks: ${env.TELEMETRY_SINKS:console,sqlite} - sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/open-benchmark/trace_store.db} - datasetio: - - provider_id: huggingface - provider_type: remote::huggingface - config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/huggingface_datasetio.db - - provider_id: localfs - provider_type: inline::localfs - config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/localfs_datasetio.db - tool_runtime: - - provider_id: brave-search - provider_type: remote::brave-search - config: - api_key: ${env.BRAVE_SEARCH_API_KEY:} - max_results: 3 - - provider_id: tavily-search - provider_type: remote::tavily-search - config: - api_key: ${env.TAVILY_SEARCH_API_KEY:} - max_results: 3 - - provider_id: code-interpreter - provider_type: inline::code-interpreter - config: {} - - provider_id: rag-runtime - provider_type: inline::rag-runtime - config: {} - - provider_id: model-context-protocol - provider_type: remote::model-context-protocol - config: {} -metadata_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db -models: -- metadata: {} - model_id: openai/gpt-4o - provider_id: openai - provider_model_id: openai/gpt-4o - model_type: llm -- metadata: {} - model_id: anthropic/claude-3-5-sonnet-latest - provider_id: anthropic - provider_model_id: anthropic/claude-3-5-sonnet-latest - model_type: llm -- metadata: {} - model_id: gemini/gemini-1.5-flash - provider_id: gemini - provider_model_id: gemini/gemini-1.5-flash - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.3-70B-Instruct - provider_id: groq - provider_model_id: groq/llama-3.3-70b-versatile - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.1-405B-Instruct - provider_id: together - provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo - model_type: llm -shields: -- shield_id: meta-llama/Llama-Guard-3-8B -vector_dbs: [] -datasets: -- purpose: eval/messages-answer - source: - type: uri - uri: huggingface://datasets/llamastack/simpleqa?split=train - metadata: {} - dataset_id: simpleqa -- purpose: eval/messages-answer - source: - type: uri - uri: huggingface://datasets/llamastack/mmlu_cot?split=test&name=all - metadata: {} - dataset_id: mmlu_cot -- purpose: eval/messages-answer - source: - type: uri - uri: huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main - metadata: {} - dataset_id: gpqa_cot -- purpose: eval/messages-answer - source: - type: uri - uri: huggingface://datasets/llamastack/math_500?split=test - metadata: {} - dataset_id: math_500 -- purpose: eval/messages-answer - source: - type: uri - uri: huggingface://datasets/llamastack/bfcl_v3?split=train - metadata: {} - dataset_id: bfcl -- purpose: eval/messages-answer - source: - type: uri - uri: huggingface://datasets/llamastack/IfEval?split=train - metadata: {} - dataset_id: ifeval -- purpose: eval/messages-answer - source: - type: uri - uri: huggingface://datasets/llamastack/docvqa?split=val - metadata: {} - dataset_id: docvqa -benchmarks: [] -tool_groups: -- toolgroup_id: builtin::websearch - provider_id: tavily-search -- toolgroup_id: builtin::rag - provider_id: rag-runtime -- toolgroup_id: builtin::code_interpreter - provider_id: code-interpreter -server: - port: 8321