From d117bfe59772e2b93c002a0fdbe21ba2cb174a97 Mon Sep 17 00:00:00 2001
From: yyymeta <123776235+yyymeta@users.noreply.github.com>
Date: Wed, 19 Mar 2025 14:56:14 -0700
Subject: [PATCH] feat: [new open benchmark] DocVQA (#1647)

# What does this PR do?

DocVQA asks the model to look at a picture and then answer a question, posed in text, with a text answer based on the textual information in the picture. These questions often require understanding the relative positions of pieces of text within the picture. The original dataset is defined as "Task 1" at https://www.docvqa.org/datasets.

## Test Plan

Set up the llama server with

```
llama stack run ./llama_stack/templates/open-benchmark/run.yaml
```

then send traffic:

```
llama-stack-client eval run-benchmark "meta-reference-docvqa" --model-id meta-llama/Llama-3.3-70B-Instruct --output-dir /tmp/gpqa --num-examples 200
```
---
 .../providers/inline/scoring/basic/scoring.py |   2 +
 .../basic/scoring_fn/docvqa_scoring_fn.py     | 245 ++++++++++++++++++
 .../basic/scoring_fn/fn_defs/docvqa.py        |  24 ++
 .../open-benchmark/open_benchmark.py          |  12 +
 llama_stack/templates/open-benchmark/run.yaml |  11 +
 .../providers/inference/test_remote_vllm.py   |   2 +-
 6 files changed, 295 insertions(+), 1 deletion(-)
 create mode 100644 llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py
 create mode 100644 llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py

diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py
index a735166e1..095d46cf5 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring.py
@@ -23,6 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
 from .config import BasicScoringConfig
 from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
+from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
 from .scoring_fn.regex_parser_math_response_scoring_fn import (
     RegexParserMathResponseScoringFn,
@@ -36,6 +37,7 @@ FIXED_FNS = [
     RegexParserScoringFn,
     RegexParserMathResponseScoringFn,
     BFCLScoringFn,
+    DocVQAScoringFn,
 ]
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py
new file mode 100644
index 000000000..84ca55732
--- /dev/null
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py
@@ -0,0 +1,245 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
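+
+# Scorer for the DocVQA benchmark: a generated answer counts as correct when,
+# after text normalization, it exactly matches one of the acceptable reference
+# answers. The tables below drive that normalization: contraction spellings,
+# spelled-out numbers, and filler words to drop (see normalize_answer below).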
+
+import json
+import re
+from typing import Any, Dict, Optional
+
+from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring_functions import ScoringFnParams
+from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
+
+from .fn_defs.docvqa import docvqa
+
+CONTRACTIONS = {
+    "aint": "ain't",
+    "arent": "aren't",
+    "cant": "can't",
+    "couldve": "could've",
+    "couldnt": "couldn't",
+    "couldn'tve": "couldn't've",
+    "couldnt've": "couldn't've",
+    "didnt": "didn't",
+    "doesnt": "doesn't",
+    "dont": "don't",
+    "hadnt": "hadn't",
+    "hadnt've": "hadn't've",
+    "hadn'tve": "hadn't've",
+    "hasnt": "hasn't",
+    "havent": "haven't",
+    "hed": "he'd",
+    "hed've": "he'd've",
+    "he'dve": "he'd've",
+    "hes": "he's",
+    "howd": "how'd",
+    "howll": "how'll",
+    "hows": "how's",
+    "Id've": "I'd've",
+    "I'dve": "I'd've",
+    "Im": "I'm",
+    "Ive": "I've",
+    "isnt": "isn't",
+    "itd": "it'd",
+    "itd've": "it'd've",
+    "it'dve": "it'd've",
+    "itll": "it'll",
+    "let's": "let's",
+    "maam": "ma'am",
+    "mightnt": "mightn't",
+    "mightnt've": "mightn't've",
+    "mightn'tve": "mightn't've",
+    "mightve": "might've",
+    "mustnt": "mustn't",
+    "mustve": "must've",
+    "neednt": "needn't",
+    "notve": "not've",
+    "oclock": "o'clock",
+    "oughtnt": "oughtn't",
+    "ow's'at": "'ow's'at",
+    "'ows'at": "'ow's'at",
+    "'ow'sat": "'ow's'at",
+    "shant": "shan't",
+    "shed've": "she'd've",
+    "she'dve": "she'd've",
+    "she's": "she's",
+    "shouldve": "should've",
+    "shouldnt": "shouldn't",
+    "shouldnt've": "shouldn't've",
+    "shouldn'tve": "shouldn't've",
+    "somebodyd": "somebody'd",
+    "somebodyd've": "somebody'd've",
+    "somebody'dve": "somebody'd've",
+    "somebodyll": "somebody'll",
+    "somebodys": "somebody's",
+    "someoned": "someone'd",
+    "someoned've": "someone'd've",
+    "someone'dve": "someone'd've",
+    "someonell": "someone'll",
+    "someones": "someone's",
+    "somethingd": "something'd",
+    "somethingd've": "something'd've",
+    "something'dve": "something'd've",
+    "somethingll": "something'll",
+    "thats": "that's",
+    "thered": "there'd",
+    "thered've": "there'd've",
+    "there'dve": "there'd've",
+    "therere": "there're",
+    "theres": "there's",
+    "theyd": "they'd",
+    "theyd've": "they'd've",
+    "they'dve": "they'd've",
+    "theyll": "they'll",
+    "theyre": "they're",
+    "theyve": "they've",
+    "twas": "'twas",
+    "wasnt": "wasn't",
+    "wed've": "we'd've",
+    "we'dve": "we'd've",
+    "weve": "we've",
+    "werent": "weren't",
+    "whatll": "what'll",
+    "whatre": "what're",
+    "whats": "what's",
+    "whatve": "what've",
+    "whens": "when's",
+    "whered": "where'd",
+    "wheres": "where's",
+    "whereve": "where've",
+    "whod": "who'd",
+    "whod've": "who'd've",
+    "who'dve": "who'd've",
+    "wholl": "who'll",
+    "whos": "who's",
+    "whove": "who've",
+    "whyll": "why'll",
+    "whyre": "why're",
+    "whys": "why's",
+    "wont": "won't",
+    "wouldve": "would've",
+    "wouldnt": "wouldn't",
+    "wouldnt've": "wouldn't've",
+    "wouldn'tve": "wouldn't've",
+    "yall": "y'all",
+    "yall'll": "y'all'll",
+    "y'allll": "y'all'll",
+    "yall'd've": "y'all'd've",
+    "y'alld've": "y'all'd've",
+    "y'all'dve": "y'all'd've",
+    "youd": "you'd",
+    "youd've": "you'd've",
+    "you'dve": "you'd've",
+    "youll": "you'll",
+    "youre": "you're",
+    "youve": "you've",
+    "1st": "first",
+    "2nd": "second",
+    "3rd": "third",
+}
+NUMBERS = {
+    "none": "0",
+    "zero": "0",
+    "one": "1",
+    "two": "2",
+    "three": "3",
+    "four": "4",
+    "five": "5",
+    "six": "6",
+    "seven": "7",
+    "eight": "8",
+    "nine": "9",
+    "ten": "10",
+}
+ARTICLES = [
+    "a",
+    "an",
+    "the",
+    "to",
+    "in",
+    "from",
+    "by",
+]  # Includes a few prepositions as well as articles; none of these words should affect the accuracy
+PERIOD_STRIP = re.compile(r"(?<!\d)(\.)(?!\d)")  # periods that are not decimal points
+COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
+PUNCTUATION = [
+    ";",
+    r"/",
+    "[",
+    "]",
+    '"',
+    "{",
+    "}",
+    "(",
+    ")",
+    "=",
+    "+",
+    "\\",
+    "_",
+    "-",
+    ">",
+    "<",
+    "@",
+    "`",
+    ",",
+    "?",
+    "!",
+]
+
+
+def normalize_answer(s: str) -> str:
+    # process punctuation
+    for p in PUNCTUATION:
+        if (p + " " in s or " " + p in s) or (re.search(COMMA_STRIP, s) is not None):
+            s = s.replace(p, "")
+        else:
+            s = s.replace(p, " ")
+    s = PERIOD_STRIP.sub("", s)
+
+    # process digits and articles
+    temp_text = s.lower().split()
+    out_text = []
+    for word in temp_text:
+        word = NUMBERS.get(word, word)  # map spelled-out numbers to digits
+        if word not in ARTICLES:
+            out_text.append(word)
+
+    # standardize contractions
+    for word_id, word in enumerate(out_text):
+        if word in CONTRACTIONS:
+            out_text[word_id] = CONTRACTIONS[word]
+    return " ".join(out_text)
+
+
+class DocVQAScoringFn(RegisteredBaseScoringFn):
+    """
+    DocVQA scoring matches the generated answer against a set of allowed
+    reference answers; both sides are normalized first so that trivial
+    formatting differences are not penalized.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.supported_fn_defs_registry = {
+            docvqa.identifier: docvqa,
+        }
+
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = "docvqa",
+        scoring_params: Optional[ScoringFnParams] = None,
+    ) -> ScoringResultRow:
+        expected_answers = json.loads(input_row["expected_answer"])
+        generated_answer = input_row["generated_answer"]
+        score = 1.0 if normalize_answer(generated_answer) in [normalize_answer(s) for s in expected_answers] else 0.0
+        return {
+            "score": score,
+        }
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py
new file mode 100644
index 000000000..aad3dfe26
--- /dev/null
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py
@@ -0,0 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
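+
+# Definition of basic::docvqa, the scoring function used by the
+# meta-reference-docvqa benchmark; results are aggregated as accuracy.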
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import (
+    AggregationFunctionType,
+    BasicScoringFnParams,
+    ScoringFn,
+)
+
+docvqa = ScoringFn(
+    identifier="basic::docvqa",
+    description="DocVQA (document visual question answering) scoring function",
+    return_type=NumberType(),
+    provider_id="basic",
+    provider_resource_id="docvqa",
+    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
+)
diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py
index acfbd78d6..d1c27e901 100644
--- a/llama_stack/templates/open-benchmark/open_benchmark.py
+++ b/llama_stack/templates/open-benchmark/open_benchmark.py
@@ -203,6 +203,13 @@ def get_distribution_template() -> DistributionTemplate:
                 uri="huggingface://datasets/llamastack/bfcl_v3?split=train",
             ),
         ),
+        DatasetInput(
+            dataset_id="docvqa",
+            purpose=DatasetPurpose.eval_messages_answer,
+            source=URIDataSource(
+                uri="huggingface://datasets/llamastack/docvqa?split=val",
+            ),
+        ),
     ]
 
     default_benchmarks = [
@@ -231,6 +238,11 @@ def get_distribution_template() -> DistributionTemplate:
             dataset_id="bfcl",
             scoring_functions=["basic::bfcl"],
         ),
+        BenchmarkInput(
+            benchmark_id="meta-reference-docvqa",
+            dataset_id="docvqa",
+            scoring_functions=["basic::docvqa"],
+        ),
     ]
     return DistributionTemplate(
         name=name,
diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml
index 8dbf51472..80a517fe8 100644
--- a/llama_stack/templates/open-benchmark/run.yaml
+++ b/llama_stack/templates/open-benchmark/run.yaml
@@ -188,6 +188,12 @@ datasets:
     uri: huggingface://datasets/llamastack/bfcl_v3?split=train
   metadata: {}
   dataset_id: bfcl
+- purpose: eval/messages-answer
+  source:
+    type: uri
+    uri: huggingface://datasets/llamastack/docvqa?split=val
+  metadata: {}
+  dataset_id: docvqa
 scoring_fns: []
 benchmarks:
 - dataset_id: simpleqa
@@ -215,6 +221,11 @@ benchmarks:
     - basic::bfcl
   metadata: {}
   benchmark_id: meta-reference-bfcl
+- dataset_id: docvqa
+  scoring_functions:
+  - basic::docvqa
+  metadata: {}
+  benchmark_id: meta-reference-docvqa
 tool_groups:
 - toolgroup_id: builtin::websearch
   provider_id: tavily-search
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index cb0997e1a..9c2281d85 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -188,7 +188,7 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
     caplog.set_level(logging.WARNING)
 
-    # Log when event loop is blocked for more than 200ms
-    loop.slow_callback_duration = 0.2
+    # Log when event loop is blocked for more than 500ms
+    loop.slow_callback_duration = 0.5
 
     # Sleep for 500ms in our delayed http response
     sleep_time = 0.5
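
## Scoring example

For reviewers: a minimal sketch of how the new scorer behaves once this patch is applied. It is illustrative only and not part of the patch; it assumes llama-stack is installed so the module path below resolves, and that `DocVQAScoringFn` can be constructed with no arguments.

```
import asyncio
import json

from llama_stack.providers.inline.scoring.basic.scoring_fn.docvqa_scoring_fn import (
    DocVQAScoringFn,
    normalize_answer,
)

# Normalization lowercases, strips punctuation, drops articles/filler words,
# standardizes contractions, and maps spelled-out numbers to digits, so
# trivially different phrasings compare equal.
assert normalize_answer("The total is Seven.") == normalize_answer("total is 7")

# score_row is an exact match over normalized strings: 1.0 if the generation
# matches any of the JSON-encoded reference answers, else 0.0.
row = {"expected_answer": json.dumps(["7", "seven"]), "generated_answer": "Seven."}
result = asyncio.run(DocVQAScoringFn().score_row(row))
assert result == {"score": 1.0}
```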