mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
migrate scoring fns to resource (#422)
* fix after rebase * remove print --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
3802edfc50
commit
0a3b3d5fb6
16 changed files with 113 additions and 62 deletions
|
@ -37,7 +37,7 @@ class ScoreResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ScoringFunctionStore(Protocol):
|
class ScoringFunctionStore(Protocol):
|
||||||
def get_scoring_function(self, name: str) -> ScoringFnDefWithProvider: ...
|
def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
|
|
@ -22,19 +22,21 @@ from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import ParamType
|
from llama_stack.apis.common.type_system import ParamType
|
||||||
|
|
||||||
|
from llama_stack.apis.resource import Resource, ResourceType
|
||||||
|
|
||||||
|
|
||||||
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
|
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
|
||||||
# with standard metrics so they can be rolled up?
|
# with standard metrics so they can be rolled up?
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ScoringConfigType(Enum):
|
class ScoringFnParamsType(Enum):
|
||||||
llm_as_judge = "llm_as_judge"
|
llm_as_judge = "llm_as_judge"
|
||||||
regex_parser = "regex_parser"
|
regex_parser = "regex_parser"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class LLMAsJudgeScoringFnParams(BaseModel):
|
class LLMAsJudgeScoringFnParams(BaseModel):
|
||||||
type: Literal[ScoringConfigType.llm_as_judge.value] = (
|
type: Literal[ScoringFnParamsType.llm_as_judge.value] = (
|
||||||
ScoringConfigType.llm_as_judge.value
|
ScoringFnParamsType.llm_as_judge.value
|
||||||
)
|
)
|
||||||
judge_model: str
|
judge_model: str
|
||||||
prompt_template: Optional[str] = None
|
prompt_template: Optional[str] = None
|
||||||
|
@ -46,8 +48,8 @@ class LLMAsJudgeScoringFnParams(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RegexParserScoringFnParams(BaseModel):
|
class RegexParserScoringFnParams(BaseModel):
|
||||||
type: Literal[ScoringConfigType.regex_parser.value] = (
|
type: Literal[ScoringFnParamsType.regex_parser.value] = (
|
||||||
ScoringConfigType.regex_parser.value
|
ScoringFnParamsType.regex_parser.value
|
||||||
)
|
)
|
||||||
parsing_regexes: Optional[List[str]] = Field(
|
parsing_regexes: Optional[List[str]] = Field(
|
||||||
description="Regex to extract the answer from generated response",
|
description="Regex to extract the answer from generated response",
|
||||||
|
@ -65,8 +67,10 @@ ScoringFnParams = Annotated[
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ScoringFnDef(BaseModel):
|
class ScoringFn(Resource):
|
||||||
identifier: str
|
type: Literal[ResourceType.scoring_function.value] = (
|
||||||
|
ResourceType.scoring_function.value
|
||||||
|
)
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
metadata: Dict[str, Any] = Field(
|
metadata: Dict[str, Any] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
|
@ -79,28 +83,23 @@ class ScoringFnDef(BaseModel):
|
||||||
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
|
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
# We can optionally add information here to support packaging of code, etc.
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ScoringFnDefWithProvider(ScoringFnDef):
|
|
||||||
type: Literal["scoring_fn"] = "scoring_fn"
|
|
||||||
provider_id: str = Field(
|
|
||||||
description="ID of the provider which serves this dataset",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class ScoringFunctions(Protocol):
|
class ScoringFunctions(Protocol):
|
||||||
@webmethod(route="/scoring_functions/list", method="GET")
|
@webmethod(route="/scoring_functions/list", method="GET")
|
||||||
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: ...
|
async def list_scoring_functions(self) -> List[ScoringFn]: ...
|
||||||
|
|
||||||
@webmethod(route="/scoring_functions/get", method="GET")
|
@webmethod(route="/scoring_functions/get", method="GET")
|
||||||
async def get_scoring_function(
|
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ...
|
||||||
self, name: str
|
|
||||||
) -> Optional[ScoringFnDefWithProvider]: ...
|
|
||||||
|
|
||||||
@webmethod(route="/scoring_functions/register", method="POST")
|
@webmethod(route="/scoring_functions/register", method="POST")
|
||||||
async def register_scoring_function(
|
async def register_scoring_function(
|
||||||
self, function_def: ScoringFnDefWithProvider
|
self,
|
||||||
|
scoring_fn_id: str,
|
||||||
|
description: str,
|
||||||
|
return_type: ParamType,
|
||||||
|
provider_scoring_fn_id: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
params: Optional[ScoringFnParams] = None,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
|
@ -35,7 +35,7 @@ RoutableObject = Union[
|
||||||
Shield,
|
Shield,
|
||||||
MemoryBank,
|
MemoryBank,
|
||||||
Dataset,
|
Dataset,
|
||||||
ScoringFnDef,
|
ScoringFn,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ RoutableObjectWithProvider = Annotated[
|
||||||
Shield,
|
Shield,
|
||||||
MemoryBank,
|
MemoryBank,
|
||||||
Dataset,
|
Dataset,
|
||||||
ScoringFnDefWithProvider,
|
ScoringFn,
|
||||||
],
|
],
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
|
|
@ -81,7 +81,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
# so we should just override the provider in-place
|
# so we should just override the provider in-place
|
||||||
obj.provider_id = provider_id
|
obj.provider_id = provider_id
|
||||||
else:
|
else:
|
||||||
obj = cls(**obj.model_dump(), provider_id=provider_id)
|
# Create a copy of the model data and explicitly set provider_id
|
||||||
|
model_data = obj.model_dump()
|
||||||
|
model_data["provider_id"] = provider_id
|
||||||
|
obj = cls(**model_data)
|
||||||
await self.dist_registry.register(obj)
|
await self.dist_registry.register(obj)
|
||||||
|
|
||||||
# Register all objects from providers
|
# Register all objects from providers
|
||||||
|
@ -101,7 +104,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
elif api == Api.scoring:
|
elif api == Api.scoring:
|
||||||
p.scoring_function_store = self
|
p.scoring_function_store = self
|
||||||
scoring_functions = await p.list_scoring_functions()
|
scoring_functions = await p.list_scoring_functions()
|
||||||
await add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
|
await add_objects(scoring_functions, pid, ScoringFn)
|
||||||
|
|
||||||
elif api == Api.eval:
|
elif api == Api.eval:
|
||||||
p.eval_task_store = self
|
p.eval_task_store = self
|
||||||
|
@ -340,18 +343,41 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
|
|
||||||
|
|
||||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
|
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||||
return await self.get_all_with_type("scoring_fn")
|
return await self.get_all_with_type(ResourceType.scoring_function.value)
|
||||||
|
|
||||||
async def get_scoring_function(
|
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
|
||||||
self, name: str
|
return await self.get_object_by_identifier(scoring_fn_id)
|
||||||
) -> Optional[ScoringFnDefWithProvider]:
|
|
||||||
return await self.get_object_by_identifier(name)
|
|
||||||
|
|
||||||
async def register_scoring_function(
|
async def register_scoring_function(
|
||||||
self, function_def: ScoringFnDefWithProvider
|
self,
|
||||||
|
scoring_fn_id: str,
|
||||||
|
description: str,
|
||||||
|
return_type: ParamType,
|
||||||
|
provider_scoring_fn_id: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
params: Optional[ScoringFnParams] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.register_object(function_def)
|
if params is None:
|
||||||
|
params = {}
|
||||||
|
if provider_scoring_fn_id is None:
|
||||||
|
provider_scoring_fn_id = scoring_fn_id
|
||||||
|
if provider_id is None:
|
||||||
|
if len(self.impls_by_provider_id) == 1:
|
||||||
|
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||||
|
)
|
||||||
|
scoring_fn = ScoringFn(
|
||||||
|
identifier=scoring_fn_id,
|
||||||
|
description=description,
|
||||||
|
return_type=return_type,
|
||||||
|
provider_resource_id=provider_scoring_fn_id,
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
|
scoring_fn.provider_id = provider_id
|
||||||
|
await self.register_object(scoring_fn)
|
||||||
|
|
||||||
|
|
||||||
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
|
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
|
||||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.apis.datasets import Dataset
|
||||||
from llama_stack.apis.eval_tasks import EvalTask
|
from llama_stack.apis.eval_tasks import EvalTask
|
||||||
from llama_stack.apis.memory_banks.memory_banks import MemoryBank
|
from llama_stack.apis.memory_banks.memory_banks import MemoryBank
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.apis.scoring_functions import ScoringFnDef
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,9 +61,9 @@ class DatasetsProtocolPrivate(Protocol):
|
||||||
|
|
||||||
|
|
||||||
class ScoringFunctionsProtocolPrivate(Protocol):
|
class ScoringFunctionsProtocolPrivate(Protocol):
|
||||||
async def list_scoring_functions(self) -> List[ScoringFnDef]: ...
|
async def list_scoring_functions(self) -> List[ScoringFn]: ...
|
||||||
|
|
||||||
async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ...
|
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class EvalTasksProtocolPrivate(Protocol):
|
class EvalTasksProtocolPrivate(Protocol):
|
||||||
|
|
|
@ -48,7 +48,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
|
|
||||||
async def shutdown(self) -> None: ...
|
async def shutdown(self) -> None: ...
|
||||||
|
|
||||||
async def list_scoring_functions(self) -> List[ScoringFnDef]:
|
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||||
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
|
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
|
||||||
for f in scoring_fn_defs_list:
|
for f in scoring_fn_defs_list:
|
||||||
assert f.identifier.startswith(
|
assert f.identifier.startswith(
|
||||||
|
@ -57,7 +57,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
|
|
||||||
return scoring_fn_defs_list
|
return scoring_fn_defs_list
|
||||||
|
|
||||||
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
|
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Registering scoring function not allowed for braintrust provider"
|
"Registering scoring function not allowed for braintrust provider"
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,12 +5,14 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.scoring_functions import ScoringFnDef
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
|
|
||||||
|
|
||||||
answer_correctness_fn_def = ScoringFnDef(
|
answer_correctness_fn_def = ScoringFn(
|
||||||
identifier="braintrust::answer-correctness",
|
identifier="braintrust::answer-correctness",
|
||||||
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
|
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
|
||||||
parameters=[],
|
params=None,
|
||||||
|
provider_id="braintrust",
|
||||||
|
provider_resource_id="answer-correctness",
|
||||||
return_type=NumberType(),
|
return_type=NumberType(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,12 +5,14 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.scoring_functions import ScoringFnDef
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
|
|
||||||
|
|
||||||
factuality_fn_def = ScoringFnDef(
|
factuality_fn_def = ScoringFn(
|
||||||
identifier="braintrust::factuality",
|
identifier="braintrust::factuality",
|
||||||
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
|
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
|
||||||
parameters=[],
|
params=None,
|
||||||
|
provider_id="braintrust",
|
||||||
|
provider_resource_id="factuality",
|
||||||
return_type=NumberType(),
|
return_type=NumberType(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -52,7 +52,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
|
|
||||||
async def shutdown(self) -> None: ...
|
async def shutdown(self) -> None: ...
|
||||||
|
|
||||||
async def list_scoring_functions(self) -> List[ScoringFnDef]:
|
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||||
scoring_fn_defs_list = [
|
scoring_fn_defs_list = [
|
||||||
fn_def
|
fn_def
|
||||||
for impl in self.scoring_fn_id_impls.values()
|
for impl in self.scoring_fn_id_impls.values()
|
||||||
|
@ -66,7 +66,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
|
|
||||||
return scoring_fn_defs_list
|
return scoring_fn_defs_list
|
||||||
|
|
||||||
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
|
async def register_scoring_function(self, function_def: ScoringFn) -> None:
|
||||||
raise NotImplementedError("Register scoring function not implemented yet")
|
raise NotImplementedError("Register scoring function not implemented yet")
|
||||||
|
|
||||||
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
|
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
|
||||||
|
|
|
@ -24,15 +24,15 @@ class BaseScoringFn(ABC):
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
||||||
def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
|
def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
|
||||||
return [x for x in self.supported_fn_defs_registry.values()]
|
return [x for x in self.supported_fn_defs_registry.values()]
|
||||||
|
|
||||||
def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None:
|
def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
|
||||||
if scoring_fn_def.identifier in self.supported_fn_defs_registry:
|
if scoring_fn.identifier in self.supported_fn_defs_registry:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Scoring function def with identifier {scoring_fn_def.identifier} already exists."
|
f"Scoring function def with identifier {scoring_fn.identifier} already exists."
|
||||||
)
|
)
|
||||||
self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def
|
self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def score_row(
|
async def score_row(
|
||||||
|
|
|
@ -5,11 +5,14 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.scoring_functions import ScoringFnDef
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
|
|
||||||
|
|
||||||
equality = ScoringFnDef(
|
equality = ScoringFn(
|
||||||
identifier="meta-reference::equality",
|
identifier="meta-reference::equality",
|
||||||
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
||||||
|
params=None,
|
||||||
|
provider_id="meta-reference",
|
||||||
|
provider_resource_id="equality",
|
||||||
return_type=NumberType(),
|
return_type=NumberType(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,11 +5,13 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.scoring_functions import ScoringFnDef
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
|
|
||||||
|
|
||||||
llm_as_judge_base = ScoringFnDef(
|
llm_as_judge_base = ScoringFn(
|
||||||
identifier="meta-reference::llm_as_judge_base",
|
identifier="meta-reference::llm_as_judge_base",
|
||||||
description="Llm As Judge Scoring Function",
|
description="Llm As Judge Scoring Function",
|
||||||
return_type=NumberType(),
|
return_type=NumberType(),
|
||||||
|
provider_id="meta-reference",
|
||||||
|
provider_resource_id="llm-as-judge-base",
|
||||||
)
|
)
|
||||||
|
|
|
@ -56,10 +56,12 @@ MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
|
||||||
r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
|
r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
|
||||||
)
|
)
|
||||||
|
|
||||||
regex_parser_multiple_choice_answer = ScoringFnDef(
|
regex_parser_multiple_choice_answer = ScoringFn(
|
||||||
identifier="meta-reference::regex_parser_multiple_choice_answer",
|
identifier="meta-reference::regex_parser_multiple_choice_answer",
|
||||||
description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
|
description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
|
||||||
return_type=NumberType(),
|
return_type=NumberType(),
|
||||||
|
provider_id="meta-reference",
|
||||||
|
provider_resource_id="regex-parser-multiple-choice-answer",
|
||||||
params=RegexParserScoringFnParams(
|
params=RegexParserScoringFnParams(
|
||||||
parsing_regexes=[
|
parsing_regexes=[
|
||||||
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
|
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
|
||||||
|
|
|
@ -5,12 +5,13 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.scoring_functions import ScoringFnDef
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
|
|
||||||
|
|
||||||
subset_of = ScoringFnDef(
|
subset_of = ScoringFn(
|
||||||
identifier="meta-reference::subset_of",
|
identifier="meta-reference::subset_of",
|
||||||
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
|
description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
|
||||||
parameters=[],
|
|
||||||
return_type=NumberType(),
|
return_type=NumberType(),
|
||||||
|
provider_id="meta-reference",
|
||||||
|
provider_resource_id="subset-of",
|
||||||
)
|
)
|
||||||
|
|
|
@ -42,7 +42,7 @@ class RegexParserScoringFn(BaseScoringFn):
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
fn_def.params is not None
|
fn_def.params is not None
|
||||||
and fn_def.params.type == ScoringConfigType.regex_parser.value
|
and fn_def.params.type == ScoringFnParamsType.regex_parser.value
|
||||||
), f"RegexParserScoringFnParams not found for {fn_def}."
|
), f"RegexParserScoringFnParams not found for {fn_def}."
|
||||||
|
|
||||||
expected_answer = input_row["expected_answer"]
|
expected_answer = input_row["expected_answer"]
|
||||||
|
|
|
@ -48,7 +48,7 @@ SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
async def scoring_stack(request):
|
async def scoring_stack(request, inference_model):
|
||||||
fixture_dict = request.param
|
fixture_dict = request.param
|
||||||
|
|
||||||
providers = {}
|
providers = {}
|
||||||
|
@ -65,4 +65,18 @@ async def scoring_stack(request):
|
||||||
provider_data,
|
provider_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
provider_id = providers["inference"][0].provider_id
|
||||||
|
await impls[Api.models].register_model(
|
||||||
|
model_id=inference_model,
|
||||||
|
provider_id=provider_id,
|
||||||
|
)
|
||||||
|
await impls[Api.models].register_model(
|
||||||
|
model_id="Llama3.1-405B-Instruct",
|
||||||
|
provider_id=provider_id,
|
||||||
|
)
|
||||||
|
await impls[Api.models].register_model(
|
||||||
|
model_id="Llama3.1-8B-Instruct",
|
||||||
|
provider_id=provider_id,
|
||||||
|
)
|
||||||
|
|
||||||
return impls
|
return impls
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue