score batch impl

This commit is contained in:
Xi Yan 2024-10-23 16:19:25 -07:00
parent 4b1d7da030
commit eb572faf6f
5 changed files with 38 additions and 7 deletions

View file

@ -17,6 +17,6 @@ async def get_provider_impl(
print("get_provider_impl", deps)
from .scoring import MetaReferenceScoringImpl
impl = MetaReferenceScoringImpl(config, deps[Api.datasetio])
impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
await impl.initialize()
return impl

View file

@ -10,6 +10,7 @@ from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from termcolor import cprint
@ -29,7 +30,10 @@ SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORE
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
def __init__(
self, config: MetaReferenceScoringConfig, datasetio_api: DatasetIO
self,
config: MetaReferenceScoringConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
@ -50,7 +54,15 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score_batch(
self, dataset_id: str, scoring_functions: List[str]
) -> ScoreBatchResponse:
print("score_batch")
rows_paginated = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
res = await self.score(
input_rows=rows_paginated.rows, scoring_functions=scoring_functions
)
cprint(f"res: {res}", "green")
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]

View file

@ -19,6 +19,7 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.impls.meta_reference.scoring.MetaReferenceScoringConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
),
]