forked from phoenix-oss/llama-stack-mirror
feat: [new open benchmark] DocVQA (#1647)
# What does this PR do? DocVQA asks the model to look at a picture, then answer a question given in text, producing a text answer based on the textual information in the picture. These questions often require understanding the relative positions of texts within the picture. The original dataset is defined in "Task 1" of https://www.docvqa.org/datasets ## Test Plan setup llama server with ``` llama stack run ./llama_stack/templates/open-benchmark/run.yaml ``` then send traffic: ``` llama-stack-client eval run-benchmark "meta-reference-docvqa" --model-id meta-llama/Llama-3.3-70B-Instruct --output-dir /tmp/docvqa --num-examples 200 ```
This commit is contained in:
parent
1902e5754c
commit
d117bfe597
6 changed files with 287 additions and 1 deletions
|
@ -23,6 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
||||||
|
|
||||||
from .config import BasicScoringConfig
|
from .config import BasicScoringConfig
|
||||||
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
||||||
|
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
|
||||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||||
from .scoring_fn.regex_parser_math_response_scoring_fn import (
|
from .scoring_fn.regex_parser_math_response_scoring_fn import (
|
||||||
RegexParserMathResponseScoringFn,
|
RegexParserMathResponseScoringFn,
|
||||||
|
@ -36,6 +37,7 @@ FIXED_FNS = [
|
||||||
RegexParserScoringFn,
|
RegexParserScoringFn,
|
||||||
RegexParserMathResponseScoringFn,
|
RegexParserMathResponseScoringFn,
|
||||||
BFCLScoringFn,
|
BFCLScoringFn,
|
||||||
|
DocVQAScoringFn,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,240 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from llama_stack.apis.scoring import ScoringResultRow
|
||||||
|
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||||
|
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||||
|
|
||||||
|
from .fn_defs.docvqa import docvqa
|
||||||
|
|
||||||
|
# Normalization tables for DocVQA answer matching, adapted from the official
# VQA evaluation script's answer-preprocessing rules.

# Maps apostrophe-less (and otherwise mangled) spellings to their canonical
# contracted forms so that e.g. "dont" and "don't" compare equal.
CONTRACTIONS = {
    "aint": "ain't",
    "arent": "aren't",
    "cant": "can't",
    "couldve": "could've",
    "couldnt": "couldn't",
    "couldn'tve": "couldn't've",
    "couldnt've": "couldn't've",
    "didnt": "didn't",
    "doesnt": "doesn't",
    "dont": "don't",
    "hadnt": "hadn't",
    "hadnt've": "hadn't've",
    "hadn'tve": "hadn't've",
    "hasnt": "hasn't",
    "havent": "haven't",
    "hed": "he'd",
    "hed've": "he'd've",
    "he'dve": "he'd've",
    "hes": "he's",
    "howd": "how'd",
    "howll": "how'll",
    "hows": "how's",
    "Id've": "I'd've",
    "I'dve": "I'd've",
    "Im": "I'm",
    "Ive": "I've",
    "isnt": "isn't",
    "itd": "it'd",
    "itd've": "it'd've",
    "it'dve": "it'd've",
    "itll": "it'll",
    "let's": "let's",
    "maam": "ma'am",
    "mightnt": "mightn't",
    "mightnt've": "mightn't've",
    "mightn'tve": "mightn't've",
    "mightve": "might've",
    "mustnt": "mustn't",
    "mustve": "must've",
    "neednt": "needn't",
    "notve": "not've",
    "oclock": "o'clock",
    "oughtnt": "oughtn't",
    "ow's'at": "'ow's'at",
    "'ows'at": "'ow's'at",
    "'ow'sat": "'ow's'at",
    "shant": "shan't",
    "shed've": "she'd've",
    "she'dve": "she'd've",
    "she's": "she's",
    "shouldve": "should've",
    "shouldnt": "shouldn't",
    "shouldnt've": "shouldn't've",
    "shouldn'tve": "shouldn't've",
    # Bug fix: this entry was reversed ("somebody'd": "somebodyd"); the key
    # must be the apostrophe-less form, matching every sibling entry.
    "somebodyd": "somebody'd",
    "somebodyd've": "somebody'd've",
    "somebody'dve": "somebody'd've",
    "somebodyll": "somebody'll",
    "somebodys": "somebody's",
    "someoned": "someone'd",
    "someoned've": "someone'd've",
    "someone'dve": "someone'd've",
    "someonell": "someone'll",
    "someones": "someone's",
    "somethingd": "something'd",
    "somethingd've": "something'd've",
    "something'dve": "something'd've",
    "somethingll": "something'll",
    "thats": "that's",
    "thered": "there'd",
    "thered've": "there'd've",
    "there'dve": "there'd've",
    "therere": "there're",
    "theres": "there's",
    "theyd": "they'd",
    "theyd've": "they'd've",
    "they'dve": "they'd've",
    "theyll": "they'll",
    "theyre": "they're",
    "theyve": "they've",
    "twas": "'twas",
    "wasnt": "wasn't",
    "wed've": "we'd've",
    "we'dve": "we'd've",
    "weve": "we've",
    "werent": "weren't",
    "whatll": "what'll",
    "whatre": "what're",
    "whats": "what's",
    "whatve": "what've",
    "whens": "when's",
    "whered": "where'd",
    "wheres": "where's",
    "whereve": "where've",
    "whod": "who'd",
    "whod've": "who'd've",
    "who'dve": "who'd've",
    "wholl": "who'll",
    "whos": "who's",
    "whove": "who've",
    "whyll": "why'll",
    "whyre": "why're",
    "whys": "why's",
    "wont": "won't",
    "wouldve": "would've",
    "wouldnt": "wouldn't",
    "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've",
    "yall": "y'all",
    "yall'll": "y'all'll",
    "y'allll": "y'all'll",
    "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've",
    "y'all'dve": "y'all'd've",
    "youd": "you'd",
    "youd've": "you'd've",
    "you'dve": "you'd've",
    "youll": "you'll",
    "youre": "you're",
    "youve": "you've",
    "1st": "first",
    "2nd": "second",
    "3rd": "third",
}
# Spelled-out numbers are canonicalized to digits before comparison.
NUMBERS = {
    "none": "0",
    "zero": "0",
    "one": "1",
    "two": "2",
    "three": "3",
    "four": "4",
    "five": "5",
    "six": "6",
    "seven": "7",
    "eight": "8",
    "nine": "9",
    "ten": "10",
}
# Filler words dropped before comparison.
ARTICLES = [
    "a",
    "an",
    "the",
    "to",
    "in",
    "from",
    "by",
]  # Contains a bit more than just articles, but we want to get rid of these elements influencing the accuracy
# Strips periods that are not decimal points.
# NOTE(review): "(?!<=\d)" looks like a typo for the lookbehind "(?<!\d)", but
# the upstream VQA eval script uses this exact pattern; kept as-is so
# benchmark scores stay comparable.
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
# Detects commas used as thousands separators, e.g. "3,000".
COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
PUNCTUATION = [
    ";",
    r"/",
    "[",
    "]",
    '"',
    "{",
    "}",
    "(",
    ")",
    "=",
    "+",
    "\\",
    "_",
    "-",
    ">",
    "<",
    "@",
    "`",
    ",",
    "?",
    "!",
]


def normalize_answer(s: str) -> str:
    """Normalize an answer string for DocVQA comparison.

    Lowercases, strips punctuation, canonicalizes spelled-out numbers and
    contractions, and drops articles so that trivially different answers
    compare equal.
    """
    # process punctuation: punctuation adjacent to whitespace (or any
    # punctuation when a thousands-separator comma is present) is removed
    # outright; otherwise it becomes a space so "a-b" still splits in two.
    for p in PUNCTUATION:
        if (p + " " in s or " " + p in s) or (re.search(COMMA_STRIP, s) is not None):
            s = s.replace(p, "")
        else:
            s = s.replace(p, " ")
    # Bug fix: Pattern.sub's third positional argument is `count`, not
    # `flags`; passing re.UNICODE (== 32) silently capped replacements at 32.
    s = PERIOD_STRIP.sub("", s)

    # process digits and articles
    temp_text = s.lower().split()
    out_text = []
    for word in temp_text:
        # Bug fix: setdefault() mutated the module-level NUMBERS table by
        # inserting every unknown word; get() is a pure lookup.
        word = NUMBERS.get(word, word)
        if word not in ARTICLES:
            out_text.append(word)

    # standardize contractions
    for word_id, word in enumerate(out_text):
        if word in CONTRACTIONS:
            out_text[word_id] = CONTRACTIONS[word]
    return " ".join(out_text)
|
||||||
|
|
||||||
|
|
||||||
|
class DocVQAScoringFn(RegisteredBaseScoringFn):
    """Scoring function for the DocVQA benchmark.

    DocVQA accepts any one of several allowed answer strings; both the
    generated answer and the allowed choices are normalized first so that
    trivial differences (case, punctuation, articles, contractions) are not
    penalized.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.supported_fn_defs_registry = {docvqa.identifier: docvqa}

    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = "docvqa",
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> ScoringResultRow:
        # "expected_answer" holds a JSON-encoded list of acceptable answers.
        acceptable = {normalize_answer(ans) for ans in json.loads(input_row["expected_answer"])}
        matched = normalize_answer(input_row["generated_answer"]) in acceptable
        return {"score": 1.0 if matched else 0.0}
|
|
@ -0,0 +1,21 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
|
from llama_stack.apis.scoring_functions import (
|
||||||
|
AggregationFunctionType,
|
||||||
|
BasicScoringFnParams,
|
||||||
|
ScoringFn,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Registered definition for the DocVQA scoring function: exact match of the
# normalized generated answer against the normalized allowed answers,
# aggregated as accuracy across the dataset.
docvqa = ScoringFn(
    identifier="basic::docvqa",
    description="DocVQA Visual Question & Answer scoring function",
    return_type=NumberType(),
    provider_id="basic",
    provider_resource_id="docvqa",
    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)
|
|
@ -203,6 +203,13 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
uri="huggingface://datasets/llamastack/bfcl_v3?split=train",
|
uri="huggingface://datasets/llamastack/bfcl_v3?split=train",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
DatasetInput(
|
||||||
|
dataset_id="docvqa",
|
||||||
|
purpose=DatasetPurpose.eval_messages_answer,
|
||||||
|
source=URIDataSource(
|
||||||
|
uri="huggingface://datasets/llamastack/docvqa?split=val",
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
default_benchmarks = [
|
default_benchmarks = [
|
||||||
|
@ -231,6 +238,11 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
dataset_id="bfcl",
|
dataset_id="bfcl",
|
||||||
scoring_functions=["basic::bfcl"],
|
scoring_functions=["basic::bfcl"],
|
||||||
),
|
),
|
||||||
|
BenchmarkInput(
|
||||||
|
benchmark_id="meta-reference-docvqa",
|
||||||
|
dataset_id="docvqa",
|
||||||
|
scoring_functions=["basic::docvqa"],
|
||||||
|
),
|
||||||
]
|
]
|
||||||
return DistributionTemplate(
|
return DistributionTemplate(
|
||||||
name=name,
|
name=name,
|
||||||
|
|
|
@ -188,6 +188,12 @@ datasets:
|
||||||
uri: huggingface://datasets/llamastack/bfcl_v3?split=train
|
uri: huggingface://datasets/llamastack/bfcl_v3?split=train
|
||||||
metadata: {}
|
metadata: {}
|
||||||
dataset_id: bfcl
|
dataset_id: bfcl
|
||||||
|
- purpose: eval/messages-answer
|
||||||
|
source:
|
||||||
|
type: uri
|
||||||
|
uri: huggingface://datasets/llamastack/docvqa?split=val
|
||||||
|
metadata: {}
|
||||||
|
dataset_id: docvqa
|
||||||
scoring_fns: []
|
scoring_fns: []
|
||||||
benchmarks:
|
benchmarks:
|
||||||
- dataset_id: simpleqa
|
- dataset_id: simpleqa
|
||||||
|
@ -215,6 +221,11 @@ benchmarks:
|
||||||
- basic::bfcl
|
- basic::bfcl
|
||||||
metadata: {}
|
metadata: {}
|
||||||
benchmark_id: meta-reference-bfcl
|
benchmark_id: meta-reference-bfcl
|
||||||
|
- dataset_id: docvqa
|
||||||
|
scoring_functions:
|
||||||
|
- basic::docvqa
|
||||||
|
metadata: {}
|
||||||
|
benchmark_id: meta-reference-docvqa
|
||||||
tool_groups:
|
tool_groups:
|
||||||
- toolgroup_id: builtin::websearch
|
- toolgroup_id: builtin::websearch
|
||||||
provider_id: tavily-search
|
provider_id: tavily-search
|
||||||
|
|
|
@ -188,7 +188,7 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
|
||||||
caplog.set_level(logging.WARNING)
|
caplog.set_level(logging.WARNING)
|
||||||
|
|
||||||
# Log when event loop is blocked for more than 200ms
|
# Log when event loop is blocked for more than 500ms
|
||||||
loop.slow_callback_duration = 0.2
|
loop.slow_callback_duration = 0.5
|
||||||
# Sleep for 500ms in our delayed http response
|
# Sleep for 500ms in our delayed http response
|
||||||
sleep_time = 0.5
|
sleep_time = 0.5
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue