score batch impl

This commit is contained in:
Xi Yan 2024-10-23 16:19:25 -07:00
parent 4b1d7da030
commit eb572faf6f
5 changed files with 38 additions and 7 deletions

View file

@ -36,7 +36,10 @@ class ScoringClient(Scoring):
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/scoring/score_batch",
params={},
json={
"dataset_id": dataset_id,
"scoring_functions": scoring_functions,
},
headers={"Content-Type": "application/json"},
timeout=60,
)
@ -44,7 +47,7 @@ class ScoringClient(Scoring):
if not response.json():
return
return ScoreResponse(**response.json())
return ScoreBatchResponse(**response.json())
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
@ -112,6 +115,14 @@ async def run_main(host: str, port: int):
)
cprint(f"scoring response={response}", "blue")
# test scoring batch using datasetio api
scoring_client = ScoringClient(f"http://{host}:{port}")
response = await scoring_client.score_batch(
dataset_id="test-dataset",
scoring_functions=["equality"],
)
cprint(f"scoring response={response}", "blue")
def main(host: str, port: int):
asyncio.run(run_main(host, port))

View file

@ -211,8 +211,15 @@ class ScoringRouter(Scoring):
async def score_batch(
self, dataset_id: str, scoring_functions: List[str]
) -> ScoreBatchResponse:
# TODO
pass
print("Score Batch!")
for fn_identifier in scoring_functions:
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score_batch(
dataset_id=dataset_id,
scoring_functions=[fn_identifier],
)
print(score_response)
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]

View file

@ -17,6 +17,6 @@ async def get_provider_impl(
print("get_provider_impl", deps)
from .scoring import MetaReferenceScoringImpl
impl = MetaReferenceScoringImpl(config, deps[Api.datasetio])
impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
await impl.initialize()
return impl

View file

@ -10,6 +10,7 @@ from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from termcolor import cprint
@ -29,7 +30,10 @@ SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORE
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
def __init__(
self, config: MetaReferenceScoringConfig, datasetio_api: DatasetIO
self,
config: MetaReferenceScoringConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
@ -50,7 +54,15 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score_batch(
self, dataset_id: str, scoring_functions: List[str]
) -> ScoreBatchResponse:
print("score_batch")
rows_paginated = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
res = await self.score(
input_rows=rows_paginated.rows, scoring_functions=scoring_functions
)
cprint(f"res: {res}", "green")
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]

View file

@ -19,6 +19,7 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.impls.meta_reference.scoring.MetaReferenceScoringConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
),
]