llm as judge, move folders

Xi Yan 2024-10-25 16:41:36 -07:00
parent bf8bc7a781
commit 16620a8185
13 changed files with 173 additions and 66 deletions

View file

@@ -2,3 +2,4 @@ include requirements.txt
 include llama_stack/distribution/*.sh
 include llama_stack/cli/scripts/*.sh
 include llama_stack/templates/*/build.yaml
+include llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/*.json

View file

@@ -97,6 +97,7 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
         )
         provider = self.routing_table.get_provider_impl(model)
         if stream:
             return (chunk async for chunk in await provider.chat_completion(**params))
         else:

View file

@@ -16,6 +16,8 @@ async def get_provider_impl(
 ):
     from .scoring import MetaReferenceScoringImpl

-    impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
+    impl = MetaReferenceScoringImpl(
+        config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference]
+    )
     await impl.initialize()
     return impl
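
Note (not part of the diff): the provider entry point now needs the inference API in its `deps` mapping in addition to datasetio and datasets. A rough sketch of the expected shape — the `config` and `*_impl` names below are placeholders, not code from this commit:

    deps = {
        Api.datasetio: datasetio_impl,   # as before
        Api.datasets: datasets_impl,     # as before
        Api.inference: inference_impl,   # new: used by the LLM-as-judge scoring fn
    }
    impl = await get_provider_impl(config, deps)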

View file

@@ -11,24 +11,25 @@ from llama_stack.apis.scoring_functions import * # noqa: F403
 from llama_stack.apis.common.type_system import * # noqa: F403
 from llama_stack.apis.datasetio import * # noqa: F403
 from llama_stack.apis.datasets import * # noqa: F403
+from llama_stack.apis.inference.inference import Inference
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import (
     EqualityScoringFn,
 )
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
+    LlmAsJudgeScoringFn,
+)
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import (
     SubsetOfScoringFn,
 )

 from .config import MetaReferenceScoringConfig

-SUPPORTED_SCORING_FNS = [
-    EqualityScoringFn,
-    SubsetOfScoringFn,
-]
+FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn]

-SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORING_FNS}
+LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]


 class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
@@ -37,17 +38,34 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         config: MetaReferenceScoringConfig,
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
+        inference_api: Inference,
     ) -> None:
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
+        self.inference_api = inference_api
+        self.scoring_fn_id_impls = {}

-    async def initialize(self) -> None: ...
+    async def initialize(self) -> None:
+        for x in FIXED_FNS:
+            impl = x()
+            await impl.initialize()
+            for fn_defs in impl.get_supported_scoring_fn_defs():
+                self.scoring_fn_id_impls[fn_defs.identifier] = impl
+        for x in LLM_JUDGE_FNS:
+            impl = x(inference_api=self.inference_api)
+            await impl.initialize()
+            for fn_defs in impl.get_supported_scoring_fn_defs():
+                self.scoring_fn_id_impls[fn_defs.identifier] = impl

     async def shutdown(self) -> None: ...

     async def list_scoring_functions(self) -> List[ScoringFnDef]:
-        return [x.scoring_function_def for x in SUPPORTED_SCORING_FNS]
+        return [
+            fn_defs
+            for impl in self.scoring_fn_id_impls.values()
+            for fn_defs in impl.get_supported_scoring_fn_defs()
+        ]

     async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
         raise NotImplementedError(
@@ -99,9 +117,9 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions:
-            if scoring_fn_id not in SCORER_REGISTRY:
+            if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-            scoring_fn = SCORER_REGISTRY[scoring_fn_id]()
+            scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
             score_results = await scoring_fn.score(input_rows, scoring_fn_id)
             agg_results = await scoring_fn.aggregate(score_results)
             res[scoring_fn_id] = ScoringResult(
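
Note (not part of the diff): after `initialize()` runs, `scoring_fn_id_impls` maps every identifier declared in the fn_defs JSON files to one shared scoring-fn instance, and `score`/`score_batch` dispatch by that identifier. A minimal sketch, with `config` and the `*_api` arguments as placeholders:

    impl = MetaReferenceScoringImpl(config, datasetio_api, datasets_api, inference_api)
    await impl.initialize()

    # Identifiers come from the fn_defs/*.json files added in this commit.
    assert set(impl.scoring_fn_id_impls) == {
        "meta-reference::equality",
        "meta-reference::subset_of",
        "meta-reference::llm_as_judge_8b_correctness",
    }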

View file

@@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
 from typing import Any, Dict, List

 from llama_stack.apis.scoring_functions import * # noqa: F401, F403
 from llama_stack.apis.scoring import * # noqa: F401, F403
+import json


 class BaseScoringFn(ABC):
@@ -17,14 +18,30 @@ class BaseScoringFn(ABC):
     - aggregate(self, scoring_fn_results)
     """

-    scoring_function_def: ScoringFnDef
-
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
+        self.supported_fn_defs_registry = {}
+        self.defs_paths = []

     def __str__(self) -> str:
         return self.__class__.__name__

+    async def initialize(self) -> None:
+        for f in self.defs_paths:
+            with open(f, "r") as f:
+                scoring_fn_def = ScoringFnDef(**json.load(f))
+                self.register_scoring_fn_def(scoring_fn_def)
+
+    def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
+        return [x for x in self.supported_fn_defs_registry.values()]
+
+    def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None:
+        if scoring_fn_def.identifier in self.supported_fn_defs_registry:
+            raise ValueError(
+                f"Scoring function def with identifier {scoring_fn_def.identifier} already exists."
+            )
+        self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def
+
     @abstractmethod
     async def score_row(
         self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
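
Note (not part of the diff): with this base class, a concrete scoring function only has to point `defs_paths` at its JSON defs and implement `score_row`/`aggregate`; the inherited `initialize()` loads and registers the defs. A hypothetical example, assuming the same imports as the scoring_fn modules below (`BaseScoringFn`, `FN_DEFS_PATH`, `aggregate_accuracy`); the `exact_prefix` name and JSON file are illustrative, not part of this commit:

    class ExactPrefixScoringFn(BaseScoringFn):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # fn_defs/exact_prefix.json would have to exist alongside the other defs.
            self.defs_paths = [FN_DEFS_PATH / "exact_prefix.json"]

        async def score_row(
            self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
        ) -> ScoringResultRow:
            generated = input_row["generated_answer"]
            expected = input_row["expected_answer"]
            return {"score": 1.0 if generated.startswith(expected) else 0.0}

        async def aggregate(
            self, scoring_results: List[ScoringResultRow]
        ) -> Dict[str, Any]:
            return aggregate_accuracy(scoring_results)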

View file

@@ -3,10 +3,13 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from pathlib import Path
 from typing import Any, Dict, List

 from llama_stack.apis.scoring import ScoringResultRow

+FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
+

 def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
     num_correct = sum(result["score"] for result in scoring_results)
@@ -17,3 +20,12 @@ def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
         "num_correct": num_correct,
         "num_total": len(scoring_results),
     }
+
+
+def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    return {
+        "average": sum(
+            result["score"] for result in scoring_results if result["score"] is not None
+        )
+        / len([_ for _ in scoring_results if _["score"] is not None]),
+    }
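
A quick worked example of the new aggregation (not part of the diff): rows whose score is None are dropped from both the numerator and the denominator.

    rows = [{"score": 5}, {"score": 3}, {"score": None}]
    assert aggregate_average(rows) == {"average": 4.0}  # (5 + 3) / 2; the None row is skipped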

View file

@@ -10,8 +10,10 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
 from llama_stack.apis.scoring_functions import * # noqa: F401, F403
 from llama_stack.apis.scoring import * # noqa: F401, F403
 from llama_stack.apis.common.type_system import * # noqa: F403
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
     aggregate_accuracy,
+    FN_DEFS_PATH,
 )
@@ -20,12 +22,9 @@ class EqualityScoringFn(BaseScoringFn):
     A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
     """

-    scoring_function_def = ScoringFnDef(
-        identifier="equality",
-        description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
-        parameters=[],
-        return_type=NumberType(),
-    )
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.defs_paths = [FN_DEFS_PATH / "equality.json"]

     async def score_row(
         self,

View file

@@ -0,0 +1,10 @@
+{
+    "identifier": "meta-reference::equality",
+    "description": "Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
+    "metadata": {},
+    "parameters": [],
+    "return_type": {
+        "type": "number"
+    },
+    "context": null
+}

View file

@@ -0,0 +1,13 @@
+{
+    "identifier": "meta-reference::llm_as_judge_8b_correctness",
+    "description": "Llm As Judge Scoring Function",
+    "metadata": {},
+    "parameters": [],
+    "return_type": {
+        "type": "number"
+    },
+    "context": {
+        "judge_model": "Llama3.1-8B-Instruct",
+        "prompt_template": "\nYou will be given a question, a expected_answer, and a system_answer.\nYour task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.\nGive your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.\nProvide your feedback as follows:\nFeedback:::\nTotal rating: (your rating, as a int between 0 and 5)\nNow here are the question, expected_answer, system_answer.\nQuestion: {input_query}\nExpected Answer: {expected_answer}\nSystem Answer: {generated_answer}\nFeedback:::\nTotal rating:\n"
+    }
+}

View file

@@ -0,0 +1,10 @@
+{
+    "identifier": "meta-reference::subset_of",
+    "description": "Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
+    "metadata": {},
+    "parameters": [],
+    "return_type": {
+        "type": "number"
+    },
+    "context": null
+}

View file

@@ -3,31 +3,19 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from llama_stack.apis.inference.inference import Inference
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
     BaseScoringFn,
 )
 from llama_stack.apis.scoring_functions import * # noqa: F401, F403
 from llama_stack.apis.scoring import * # noqa: F401, F403
 from llama_stack.apis.common.type_system import * # noqa: F403
-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
-    aggregate_accuracy,
-)
-
-JUDGE_PROMPT = """
-You will be given a question, a expected_answer, and a system_answer.
-Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
-Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
-Provide your feedback as follows:
-Feedback:::
-Total rating: (your rating, as a int between 0 and 5)
-Now here are the question, expected_answer, system_answer.
-Question: {question}
-Expected Answer: {expected_answer}
-System Answer: {answer}
-Feedback:::
-Total rating:
-"""
+import re
+
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_average,
+    FN_DEFS_PATH,
+)


 class LlmAsJudgeScoringFn(BaseScoringFn):
@@ -35,27 +23,62 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
     A scoring_fn that assigns
     """

-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.scoring_fn_def_registry = {}
+    def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
+        super().__init__(*arg, **kwargs)
+        self.inference_api = inference_api
+        self.defs_paths = [FN_DEFS_PATH / "llm_as_judge_8b_correctness.json"]

-    def register_scoring_def(self, scoring_fn_def: ScoringFnDef) -> None:
-        self.scoring_function_def_registry[scoring_fn_def.identifier] = scoring_fn_def
-
-    async def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
-        assert "expected_answer" in input_row, "Expected answer not found in input row."
-        assert (
-            "generated_answer" in input_row
-        ), "Generated answer not found in input row."
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+    ) -> ScoringResultRow:
+        assert (
+            scoring_fn_identifier is not None
+        ), "Scoring function identifier not found."
+        fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+        assert (
+            fn_def.context is not None and fn_def.context.prompt_template is not None
+        ), "LLM Judge prompt_template not found."

+        input_query = input_row["input_query"]
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]

-        score = 1.0 if expected_answer == generated_answer else 0.0
+        judge_input_msg = fn_def.context.prompt_template.format(
+            input_query=input_query,
+            expected_answer=expected_answer,
+            generated_answer=generated_answer,
+        )
+
+        judge_response = await self.inference_api.chat_completion(
+            model=fn_def.context.judge_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": judge_input_msg,
+                }
+            ],
+        )
+        content = judge_response.completion_message.content
+        rating_regexs = [
+            r"Total rating: (\d+)",
+            r"rating: (\d+)",
+            r"Rating: (\d+)",
+        ]
+
+        judge_rating = None
+        for regex in rating_regexs:
+            match = re.search(regex, content)
+            if match:
+                judge_rating = int(match.group(1))
+                break
+
         return {
-            "score": score,
+            "score": judge_rating,
+            "judge_feedback": content,
         }

     async def aggregate(
         self, scoring_results: List[ScoringResultRow]
     ) -> Dict[str, Any]:
-        return aggregate_accuracy(scoring_results)
+        return aggregate_average(scoring_results)
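
Note (not part of the diff): a small self-contained sketch of how the rating extraction above behaves on a typical judge reply that follows the prompt's "Feedback::: / Total rating:" format:

    import re

    content = "Feedback:::\nTotal rating: 4"  # hypothetical judge output

    judge_rating = None
    for regex in [r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"]:
        match = re.search(regex, content)
        if match:
            judge_rating = int(match.group(1))
            break

    assert judge_rating == 4  # stays None (and the row scores None) if nothing matches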

View file

@@ -12,6 +12,7 @@ from llama_stack.apis.scoring import * # noqa: F401, F403
 from llama_stack.apis.common.type_system import * # noqa: F403
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
     aggregate_accuracy,
+    FN_DEFS_PATH,
 )
@@ -20,23 +21,15 @@ class SubsetOfScoringFn(BaseScoringFn):
     A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
     """

-    scoring_function_def = ScoringFnDef(
-        identifier="subset_of",
-        description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
-        parameters=[],
-        return_type=NumberType(),
-    )
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.defs_paths = [FN_DEFS_PATH / "subset_of.json"]

     async def score_row(
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = "subset_of",
     ) -> ScoringResultRow:
-        assert "expected_answer" in input_row, "Expected answer not found in input row."
-        assert (
-            "generated_answer" in input_row
-        ), "Generated answer not found in input row."
-
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]
         score = 1.0 if expected_answer in generated_answer else 0.0

View file

@@ -50,7 +50,9 @@ async def test_scoring_functions_list(scoring_settings):
     assert isinstance(scoring_functions, list)
     assert len(scoring_functions) > 0
     function_ids = [f.identifier for f in scoring_functions]
-    assert "equality" in function_ids
+    assert "meta-reference::equality" in function_ids
+    assert "meta-reference::subset_of" in function_ids
+    assert "meta-reference::llm_as_judge_8b_correctness" in function_ids


 @pytest.mark.asyncio
@@ -64,9 +66,15 @@ async def test_scoring_score(scoring_settings):
     response = await scoring_impl.score_batch(
         dataset_id=response[0].identifier,
-        scoring_functions=["equality", "subset_of"],
+        scoring_functions=[
+            "meta-reference::equality",
+            "meta-reference::subset_of",
+            "meta-reference::llm_as_judge_8b_correctness",
+        ],
     )

-    assert len(response.results) == 2
-    assert "equality" in response.results
-    assert "subset_of" in response.results
+    print(response)
+    assert len(response.results) == 3
+    assert "meta-reference::equality" in response.results
+    assert "meta-reference::subset_of" in response.results
+    assert "meta-reference::llm_as_judge_8b_correctness" in response.results