migrate scoring fns to resource (#422)

* fix after rebase

* remove print

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-11 17:28:48 -08:00 committed by GitHub
parent 3802edfc50
commit 0a3b3d5fb6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 113 additions and 62 deletions

View file

@ -37,7 +37,7 @@ class ScoreResponse(BaseModel):
class ScoringFunctionStore(Protocol):
def get_scoring_function(self, name: str) -> ScoringFnDefWithProvider: ...
def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
@runtime_checkable

View file

@ -22,19 +22,21 @@ from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@json_schema_type
class ScoringConfigType(Enum):
class ScoringFnParamsType(Enum):
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringConfigType.llm_as_judge.value] = (
ScoringConfigType.llm_as_judge.value
type: Literal[ScoringFnParamsType.llm_as_judge.value] = (
ScoringFnParamsType.llm_as_judge.value
)
judge_model: str
prompt_template: Optional[str] = None
@ -46,8 +48,8 @@ class LLMAsJudgeScoringFnParams(BaseModel):
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringConfigType.regex_parser.value] = (
ScoringConfigType.regex_parser.value
type: Literal[ScoringFnParamsType.regex_parser.value] = (
ScoringFnParamsType.regex_parser.value
)
parsing_regexes: Optional[List[str]] = Field(
description="Regex to extract the answer from generated response",
@ -65,8 +67,10 @@ ScoringFnParams = Annotated[
@json_schema_type
class ScoringFnDef(BaseModel):
identifier: str
class ScoringFn(Resource):
type: Literal[ResourceType.scoring_function.value] = (
ResourceType.scoring_function.value
)
description: Optional[str] = None
metadata: Dict[str, Any] = Field(
default_factory=dict,
@ -79,28 +83,23 @@ class ScoringFnDef(BaseModel):
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
default=None,
)
# We can optionally add information here to support packaging of code, etc.
@json_schema_type
class ScoringFnDefWithProvider(ScoringFnDef):
type: Literal["scoring_fn"] = "scoring_fn"
provider_id: str = Field(
description="ID of the provider which serves this dataset",
)
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring_functions/list", method="GET")
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: ...
async def list_scoring_functions(self) -> List[ScoringFn]: ...
@webmethod(route="/scoring_functions/get", method="GET")
async def get_scoring_function(
self, name: str
) -> Optional[ScoringFnDefWithProvider]: ...
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring_functions/register", method="POST")
async def register_scoring_function(
self, function_def: ScoringFnDefWithProvider
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
) -> None: ...

View file

@ -35,7 +35,7 @@ RoutableObject = Union[
Shield,
MemoryBank,
Dataset,
ScoringFnDef,
ScoringFn,
]
@ -45,7 +45,7 @@ RoutableObjectWithProvider = Annotated[
Shield,
MemoryBank,
Dataset,
ScoringFnDefWithProvider,
ScoringFn,
],
Field(discriminator="type"),
]

View file

@ -81,7 +81,10 @@ class CommonRoutingTableImpl(RoutingTable):
# so we should just override the provider in-place
obj.provider_id = provider_id
else:
obj = cls(**obj.model_dump(), provider_id=provider_id)
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
await self.dist_registry.register(obj)
# Register all objects from providers
@ -101,7 +104,7 @@ class CommonRoutingTableImpl(RoutingTable):
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.eval_task_store = self
@ -340,18 +343,41 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
return await self.get_all_with_type("scoring_fn")
async def list_scoring_functions(self) -> List[ScoringFn]:
return await self.get_all_with_type(ResourceType.scoring_function.value)
async def get_scoring_function(
self, name: str
) -> Optional[ScoringFnDefWithProvider]:
return await self.get_object_by_identifier(name)
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier(scoring_fn_id)
async def register_scoring_function(
self, function_def: ScoringFnDefWithProvider
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
) -> None:
await self.register_object(function_def)
if params is None:
params = {}
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
scoring_fn = ScoringFn(
identifier=scoring_fn_id,
description=description,
return_type=return_type,
provider_resource_id=provider_scoring_fn_id,
params=params,
)
scoring_fn.provider_id = provider_id
await self.register_object(scoring_fn)
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):

View file

@ -15,7 +15,7 @@ from llama_stack.apis.datasets import Dataset
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.memory_banks.memory_banks import MemoryBank
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
@ -61,9 +61,9 @@ class DatasetsProtocolPrivate(Protocol):
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFnDef]: ...
async def list_scoring_functions(self) -> List[ScoringFn]: ...
async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ...
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...
class EvalTasksProtocolPrivate(Protocol):

View file

@ -48,7 +48,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFnDef]:
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
@ -57,7 +57,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
raise NotImplementedError(
"Registering scoring function not allowed for braintrust provider"
)

View file

@ -5,12 +5,14 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.scoring_functions import ScoringFn
answer_correctness_fn_def = ScoringFnDef(
answer_correctness_fn_def = ScoringFn(
identifier="braintrust::answer-correctness",
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
parameters=[],
params=None,
provider_id="braintrust",
provider_resource_id="answer-correctness",
return_type=NumberType(),
)

View file

@ -5,12 +5,14 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.scoring_functions import ScoringFn
factuality_fn_def = ScoringFnDef(
factuality_fn_def = ScoringFn(
identifier="braintrust::factuality",
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
parameters=[],
params=None,
provider_id="braintrust",
provider_resource_id="factuality",
return_type=NumberType(),
)

View file

@ -52,7 +52,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFnDef]:
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [
fn_def
for impl in self.scoring_fn_id_impls.values()
@ -66,7 +66,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:

View file

@ -24,15 +24,15 @@ class BaseScoringFn(ABC):
def __str__(self) -> str:
return self.__class__.__name__
def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
return [x for x in self.supported_fn_defs_registry.values()]
def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None:
if scoring_fn_def.identifier in self.supported_fn_defs_registry:
def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
if scoring_fn.identifier in self.supported_fn_defs_registry:
raise ValueError(
f"Scoring function def with identifier {scoring_fn_def.identifier} already exists."
f"Scoring function def with identifier {scoring_fn.identifier} already exists."
)
self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def
self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn
@abstractmethod
async def score_row(

View file

@ -5,11 +5,14 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.scoring_functions import ScoringFn
equality = ScoringFnDef(
equality = ScoringFn(
identifier="meta-reference::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
params=None,
provider_id="meta-reference",
provider_resource_id="equality",
return_type=NumberType(),
)

View file

@ -5,11 +5,13 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.scoring_functions import ScoringFn
llm_as_judge_base = ScoringFnDef(
llm_as_judge_base = ScoringFn(
identifier="meta-reference::llm_as_judge_base",
description="Llm As Judge Scoring Function",
return_type=NumberType(),
provider_id="meta-reference",
provider_resource_id="llm-as-judge-base",
)

View file

@ -56,10 +56,12 @@ MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[]|[]|[]|[])"
)
regex_parser_multiple_choice_answer = ScoringFnDef(
regex_parser_multiple_choice_answer = ScoringFn(
identifier="meta-reference::regex_parser_multiple_choice_answer",
description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
return_type=NumberType(),
provider_id="meta-reference",
provider_resource_id="regex-parser-multiple-choice-answer",
params=RegexParserScoringFnParams(
parsing_regexes=[
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)

View file

@ -5,12 +5,13 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef
from llama_stack.apis.scoring_functions import ScoringFn
subset_of = ScoringFnDef(
subset_of = ScoringFn(
identifier="meta-reference::subset_of",
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
parameters=[],
return_type=NumberType(),
provider_id="meta-reference",
provider_resource_id="subset-of",
)

View file

@ -42,7 +42,7 @@ class RegexParserScoringFn(BaseScoringFn):
assert (
fn_def.params is not None
and fn_def.params.type == ScoringConfigType.regex_parser.value
and fn_def.params.type == ScoringFnParamsType.regex_parser.value
), f"RegexParserScoringFnParams not found for {fn_def}."
expected_answer = input_row["expected_answer"]

View file

@ -48,7 +48,7 @@ SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]
@pytest_asyncio.fixture(scope="session")
async def scoring_stack(request):
async def scoring_stack(request, inference_model):
fixture_dict = request.param
providers = {}
@ -65,4 +65,18 @@ async def scoring_stack(request):
provider_data,
)
provider_id = providers["inference"][0].provider_id
await impls[Api.models].register_model(
model_id=inference_model,
provider_id=provider_id,
)
await impls[Api.models].register_model(
model_id="Llama3.1-405B-Instruct",
provider_id=provider_id,
)
await impls[Api.models].register_model(
model_id="Llama3.1-8B-Instruct",
provider_id=provider_id,
)
return impls