Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-28 02:53:30 +00:00

migrate scoring fns to resource (#422)

* fix after rebase
* remove print

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>

This commit is contained in:
parent 3802edfc50
commit 0a3b3d5fb6

16 changed files with 113 additions and 62 deletions
@@ -37,7 +37,7 @@ class ScoreResponse(BaseModel):


 class ScoringFunctionStore(Protocol):
-    def get_scoring_function(self, name: str) -> ScoringFnDefWithProvider: ...
+    def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...


 @runtime_checkable
@@ -22,19 +22,21 @@ from typing_extensions import Annotated

 from llama_stack.apis.common.type_system import ParamType
+from llama_stack.apis.resource import Resource, ResourceType


 # Perhaps more structure can be imposed on these functions. Maybe they could be associated
 # with standard metrics so they can be rolled up?
 @json_schema_type
-class ScoringConfigType(Enum):
+class ScoringFnParamsType(Enum):
     llm_as_judge = "llm_as_judge"
     regex_parser = "regex_parser"


 @json_schema_type
 class LLMAsJudgeScoringFnParams(BaseModel):
-    type: Literal[ScoringConfigType.llm_as_judge.value] = (
-        ScoringConfigType.llm_as_judge.value
+    type: Literal[ScoringFnParamsType.llm_as_judge.value] = (
+        ScoringFnParamsType.llm_as_judge.value
     )
     judge_model: str
     prompt_template: Optional[str] = None
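Note: the rename from ScoringConfigType to ScoringFnParamsType keeps the params union discriminated: each params model pins its type field to a Literal of the enum value, so Pydantic can dispatch a serialized payload to the right class. A minimal construction sketch under the new names (assumes llama_stack is importable; the judge model id and template are illustrative):

from llama_stack.apis.scoring_functions import (
    LLMAsJudgeScoringFnParams,
    ScoringFnParamsType,
)

# Illustrative values; any judge-capable model id would do here.
params = LLMAsJudgeScoringFnParams(
    judge_model="Llama3.1-8B-Instruct",
    prompt_template="Score the generated answer from 0 to 10.",
)
# The pinned Literal doubles as the union discriminator.
assert params.type == ScoringFnParamsType.llm_as_judge.value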
@@ -46,8 +48,8 @@ class LLMAsJudgeScoringFnParams(BaseModel):

 @json_schema_type
 class RegexParserScoringFnParams(BaseModel):
-    type: Literal[ScoringConfigType.regex_parser.value] = (
-        ScoringConfigType.regex_parser.value
+    type: Literal[ScoringFnParamsType.regex_parser.value] = (
+        ScoringFnParamsType.regex_parser.value
     )
     parsing_regexes: Optional[List[str]] = Field(
         description="Regex to extract the answer from generated response",
@@ -65,8 +67,10 @@ ScoringFnParams = Annotated[


 @json_schema_type
-class ScoringFnDef(BaseModel):
-    identifier: str
+class ScoringFn(Resource):
+    type: Literal[ResourceType.scoring_function.value] = (
+        ResourceType.scoring_function.value
+    )
     description: Optional[str] = None
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
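Note: with ScoringFn deriving from Resource, the identifier / provider_id / provider_resource_id bookkeeping comes from the shared base class, and the pinned type lets the routable-object union discriminate on it. A hedged construction sketch (field values are illustrative):

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn

fn = ScoringFn(
    identifier="example::exact-match",      # stack-wide name (illustrative)
    provider_id="example",                  # serving provider (illustrative)
    provider_resource_id="exact-match",     # provider-local name (illustrative)
    description="1.0 if generated equals expected, else 0.0",
    return_type=NumberType(),
)
assert fn.type == ResourceType.scoring_function.value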
@@ -79,28 +83,23 @@ class ScoringFnDef(BaseModel):
         description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
         default=None,
     )
-    # We can optionally add information here to support packaging of code, etc.
-
-
-@json_schema_type
-class ScoringFnDefWithProvider(ScoringFnDef):
-    type: Literal["scoring_fn"] = "scoring_fn"
-    provider_id: str = Field(
-        description="ID of the provider which serves this dataset",
-    )


 @runtime_checkable
 class ScoringFunctions(Protocol):
     @webmethod(route="/scoring_functions/list", method="GET")
-    async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: ...
+    async def list_scoring_functions(self) -> List[ScoringFn]: ...

     @webmethod(route="/scoring_functions/get", method="GET")
-    async def get_scoring_function(
-        self, name: str
-    ) -> Optional[ScoringFnDefWithProvider]: ...
+    async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ...

     @webmethod(route="/scoring_functions/register", method="POST")
     async def register_scoring_function(
-        self, function_def: ScoringFnDefWithProvider
+        self,
+        scoring_fn_id: str,
+        description: str,
+        return_type: ParamType,
+        provider_scoring_fn_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        params: Optional[ScoringFnParams] = None,
     ) -> None: ...
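Note: the register endpoint now takes flat fields instead of a pre-assembled ScoringFnDefWithProvider, so clients no longer need to know the resource wrapper shape. A hedged usage sketch; scoring_functions_impl stands in for any object implementing the ScoringFunctions protocol:

from llama_stack.apis.common.type_system import NumberType

async def register_example(scoring_functions_impl) -> None:
    # The ids below are illustrative; provider_id, provider_scoring_fn_id,
    # and params are optional and defaulted on the server side.
    await scoring_functions_impl.register_scoring_function(
        scoring_fn_id="example::exact-match",
        description="1.0 if generated equals expected, else 0.0",
        return_type=NumberType(),
    )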
@@ -35,7 +35,7 @@ RoutableObject = Union[
     Shield,
     MemoryBank,
     Dataset,
-    ScoringFnDef,
+    ScoringFn,
 ]


@@ -45,7 +45,7 @@ RoutableObjectWithProvider = Annotated[
         Shield,
         MemoryBank,
         Dataset,
-        ScoringFnDefWithProvider,
+        ScoringFn,
     ],
     Field(discriminator="type"),
 ]
@@ -81,7 +81,10 @@ class CommonRoutingTableImpl(RoutingTable):
             # so we should just override the provider in-place
             obj.provider_id = provider_id
         else:
-            obj = cls(**obj.model_dump(), provider_id=provider_id)
+            # Create a copy of the model data and explicitly set provider_id
+            model_data = obj.model_dump()
+            model_data["provider_id"] = provider_id
+            obj = cls(**model_data)
         await self.dist_registry.register(obj)

     # Register all objects from providers
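Note: the three-line replacement sidesteps a duplicate-keyword failure. If obj.model_dump() already contains a provider_id key, the old cls(**obj.model_dump(), provider_id=provider_id) call raises a TypeError. A dependency-free illustration of the failure mode:

class Obj:
    def __init__(self, identifier: str, provider_id: str) -> None:
        self.identifier = identifier
        self.provider_id = provider_id

data = {"identifier": "fn", "provider_id": "old-provider"}
# Obj(**data, provider_id="new-provider") would raise:
#   TypeError: __init__() got multiple values for keyword argument 'provider_id'
data["provider_id"] = "new-provider"  # mutate the copied dict, then construct
obj = Obj(**data)
assert obj.provider_id == "new-provider"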
@@ -101,7 +104,7 @@ class CommonRoutingTableImpl(RoutingTable):
             elif api == Api.scoring:
                 p.scoring_function_store = self
                 scoring_functions = await p.list_scoring_functions()
-                await add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
+                await add_objects(scoring_functions, pid, ScoringFn)

             elif api == Api.eval:
                 p.eval_task_store = self
@@ -340,18 +343,41 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):


 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
-    async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
-        return await self.get_all_with_type("scoring_fn")
+    async def list_scoring_functions(self) -> List[ScoringFn]:
+        return await self.get_all_with_type(ResourceType.scoring_function.value)

-    async def get_scoring_function(
-        self, name: str
-    ) -> Optional[ScoringFnDefWithProvider]:
-        return await self.get_object_by_identifier(name)
+    async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
+        return await self.get_object_by_identifier(scoring_fn_id)

     async def register_scoring_function(
-        self, function_def: ScoringFnDefWithProvider
+        self,
+        scoring_fn_id: str,
+        description: str,
+        return_type: ParamType,
+        provider_scoring_fn_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        params: Optional[ScoringFnParams] = None,
     ) -> None:
-        await self.register_object(function_def)
+        if params is None:
+            params = {}
+        if provider_scoring_fn_id is None:
+            provider_scoring_fn_id = scoring_fn_id
+        if provider_id is None:
+            if len(self.impls_by_provider_id) == 1:
+                provider_id = list(self.impls_by_provider_id.keys())[0]
+            else:
+                raise ValueError(
+                    "No provider specified and multiple providers available. Please specify a provider_id."
+                )
+        scoring_fn = ScoringFn(
+            identifier=scoring_fn_id,
+            description=description,
+            return_type=return_type,
+            provider_resource_id=provider_scoring_fn_id,
+            params=params,
+        )
+        scoring_fn.provider_id = provider_id
+        await self.register_object(scoring_fn)


 class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
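Note: the defaulting rules above mean a deployment with exactly one scoring provider can omit provider_id entirely, while an ambiguous multi-provider setup fails fast. A hedged sketch of the error path; table stands in for a constructed ScoringFunctionsRoutingTable with two providers:

from llama_stack.apis.common.type_system import NumberType

async def expect_ambiguous_provider_error(table) -> None:
    try:
        await table.register_scoring_function(
            scoring_fn_id="example::equality",  # illustrative id
            description="illustrative scorer",
            return_type=NumberType(),
        )
    except ValueError as err:
        # Raised by the branch above when provider_id is omitted and
        # more than one scoring provider is registered.
        assert "No provider specified" in str(err)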
@@ -15,7 +15,7 @@ from llama_stack.apis.datasets import Dataset
 from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.memory_banks.memory_banks import MemoryBank
 from llama_stack.apis.models import Model
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.apis.shields import Shield


@@ -61,9 +61,9 @@ class DatasetsProtocolPrivate(Protocol):


 class ScoringFunctionsProtocolPrivate(Protocol):
-    async def list_scoring_functions(self) -> List[ScoringFnDef]: ...
+    async def list_scoring_functions(self) -> List[ScoringFn]: ...

-    async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ...
+    async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...


 class EvalTasksProtocolPrivate(Protocol):
@@ -48,7 +48,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):

     async def shutdown(self) -> None: ...

-    async def list_scoring_functions(self) -> List[ScoringFnDef]:
+    async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
         for f in scoring_fn_defs_list:
             assert f.identifier.startswith(
@@ -57,7 +57,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):

         return scoring_fn_defs_list

-    async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
+    async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
         raise NotImplementedError(
             "Registering scoring function not allowed for braintrust provider"
         )
@@ -5,12 +5,14 @@
 # the root directory of this source tree.

 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn


-answer_correctness_fn_def = ScoringFnDef(
+answer_correctness_fn_def = ScoringFn(
     identifier="braintrust::answer-correctness",
     description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
-    parameters=[],
+    params=None,
+    provider_id="braintrust",
+    provider_resource_id="answer-correctness",
     return_type=NumberType(),
 )
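Note: the built-in definitions now spell out both halves of the naming scheme: identifier is the stack-wide, provider-namespaced name, while provider_resource_id is the provider-local name. A hedged helper (not part of the codebase) that makes the convention explicit:

def provider_local_name(identifier: str) -> str:
    # Strip the "provider::" namespace prefix, if present.
    return identifier.split("::", 1)[-1]

assert provider_local_name("braintrust::answer-correctness") == "answer-correctness"
assert provider_local_name("answer-correctness") == "answer-correctness"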
@@ -5,12 +5,14 @@
 # the root directory of this source tree.

 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn


-factuality_fn_def = ScoringFnDef(
+factuality_fn_def = ScoringFn(
     identifier="braintrust::factuality",
     description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
-    parameters=[],
+    params=None,
+    provider_id="braintrust",
+    provider_resource_id="factuality",
     return_type=NumberType(),
 )
@@ -52,7 +52,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):

     async def shutdown(self) -> None: ...

-    async def list_scoring_functions(self) -> List[ScoringFnDef]:
+    async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = [
             fn_def
             for impl in self.scoring_fn_id_impls.values()
@@ -66,7 +66,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):

         return scoring_fn_defs_list

-    async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
+    async def register_scoring_function(self, function_def: ScoringFn) -> None:
         raise NotImplementedError("Register scoring function not implemented yet")

     async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
@@ -24,15 +24,15 @@ class BaseScoringFn(ABC):
     def __str__(self) -> str:
         return self.__class__.__name__

-    def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
+    def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
         return [x for x in self.supported_fn_defs_registry.values()]

-    def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None:
-        if scoring_fn_def.identifier in self.supported_fn_defs_registry:
+    def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
+        if scoring_fn.identifier in self.supported_fn_defs_registry:
             raise ValueError(
-                f"Scoring function def with identifier {scoring_fn_def.identifier} already exists."
+                f"Scoring function def with identifier {scoring_fn.identifier} already exists."
             )
-        self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def
+        self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn

     @abstractmethod
     async def score_row(
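Note: aside from the parameter rename, the registry contract is unchanged: one entry per identifier, with a ValueError on duplicates. A dependency-free sketch of the same invariant:

from typing import Any, Dict

registry: Dict[str, Any] = {}

def register_scoring_fn_def(identifier: str, scoring_fn: Any) -> None:
    # Mirrors the duplicate check above without the llama_stack types.
    if identifier in registry:
        raise ValueError(
            f"Scoring function def with identifier {identifier} already exists."
        )
    registry[identifier] = scoring_fn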
@@ -5,11 +5,14 @@
 # the root directory of this source tree.

 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn


-equality = ScoringFnDef(
+equality = ScoringFn(
     identifier="meta-reference::equality",
     description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
+    params=None,
+    provider_id="meta-reference",
+    provider_resource_id="equality",
     return_type=NumberType(),
 )
@@ -5,11 +5,13 @@
 # the root directory of this source tree.

 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn


-llm_as_judge_base = ScoringFnDef(
+llm_as_judge_base = ScoringFn(
     identifier="meta-reference::llm_as_judge_base",
     description="Llm As Judge Scoring Function",
     return_type=NumberType(),
+    provider_id="meta-reference",
+    provider_resource_id="llm-as-judge-base",
 )
@@ -56,10 +56,12 @@ MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
     r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
 )

-regex_parser_multiple_choice_answer = ScoringFnDef(
+regex_parser_multiple_choice_answer = ScoringFn(
     identifier="meta-reference::regex_parser_multiple_choice_answer",
     description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
     return_type=NumberType(),
+    provider_id="meta-reference",
+    provider_resource_id="regex-parser-multiple-choice-answer",
     params=RegexParserScoringFnParams(
         parsing_regexes=[
             MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
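Note: each entry in parsing_regexes is a template instantiated with an answer prefix and then searched against the generated text; the first capture group is the extracted answer. A self-contained sketch of that flow (the prefix string and the simplified pattern are illustrative, not taken verbatim from the diff):

import re

ANSWER_PATTERN_TEMPLATE = r"(?i){}\s*([A-D])"  # simplified variant of the template above

pattern = ANSWER_PATTERN_TEMPLATE.format(r"Answer\s*:")
match = re.search(pattern, "Reasoning... Answer: B")
extracted = match.group(1) if match else None
assert extracted == "B"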
@@ -5,12 +5,13 @@
 # the root directory of this source tree.

 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn


-subset_of = ScoringFnDef(
+subset_of = ScoringFn(
     identifier="meta-reference::subset_of",
     description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
-    parameters=[],
     return_type=NumberType(),
+    provider_id="meta-reference",
+    provider_resource_id="subset-of",
 )
@@ -42,7 +42,7 @@ class RegexParserScoringFn(BaseScoringFn):

         assert (
             fn_def.params is not None
-            and fn_def.params.type == ScoringConfigType.regex_parser.value
+            and fn_def.params.type == ScoringFnParamsType.regex_parser.value
         ), f"RegexParserScoringFnParams not found for {fn_def}."

         expected_answer = input_row["expected_answer"]
@@ -48,7 +48,7 @@ SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]


 @pytest_asyncio.fixture(scope="session")
-async def scoring_stack(request):
+async def scoring_stack(request, inference_model):
     fixture_dict = request.param

     providers = {}
@@ -65,4 +65,18 @@ async def scoring_stack(request):
         provider_data,
     )

+    provider_id = providers["inference"][0].provider_id
+    await impls[Api.models].register_model(
+        model_id=inference_model,
+        provider_id=provider_id,
+    )
+    await impls[Api.models].register_model(
+        model_id="Llama3.1-405B-Instruct",
+        provider_id=provider_id,
+    )
+    await impls[Api.models].register_model(
+        model_id="Llama3.1-8B-Instruct",
+        provider_id=provider_id,
+    )
+
     return impls
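Note: scoring_stack now pulls inference_model in through pytest's fixture resolution, so the parametrized judge model is registered before any test body runs. A hedged sketch of a consumer (the Api import path and the scoring_functions key are assumptions about the surrounding suite):

import pytest

from llama_stack.distribution.datatypes import Api  # import path assumed

@pytest.mark.asyncio
async def test_list_scoring_functions(scoring_stack):
    impls = scoring_stack
    scoring_fns = await impls[Api.scoring_functions].list_scoring_functions()
    assert isinstance(scoring_fns, list)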