forked from phoenix-oss/llama-stack-mirror
		
	* wip * dataset validation * test_scoring * cleanup * clean up test * comments * error checking * dataset client * test client: * datasetio client * clean up * basic scoring function works * scorer wip * equality scorer * score batch impl * score batch * update scoring test * refactor * validate scorer input * address comments * add all rows scores to ScoringResult * bugfix * scoring function def rename
		
			
				
	
	
		
			132 lines
		
	
	
	
		
			4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			132 lines
		
	
	
	
		
			4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import asyncio
 | |
| import os
 | |
| from pathlib import Path
 | |
| 
 | |
| import fire
 | |
| import httpx
 | |
| from termcolor import cprint
 | |
| 
 | |
| from llama_stack.apis.datasets import *  # noqa: F403
 | |
| from llama_stack.apis.scoring import *  # noqa: F403
 | |
| from llama_stack.apis.common.type_system import *  # noqa: F403
 | |
| from llama_stack.apis.datasetio.client import DatasetIOClient
 | |
| from llama_stack.apis.datasets.client import DatasetsClient
 | |
| from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file
 | |
| 
 | |
| 
 | |
class ScoringClient(Scoring):
    """HTTP client for the scoring endpoints of a running server.

    Every request opens a short-lived ``httpx.AsyncClient`` and POSTs a JSON
    payload to a route underneath ``base_url``.
    """

    def __init__(self, base_url: str):
        """Store the server root URL.

        Routes are appended as ``/scoring/...``, so ``base_url`` should be
        given without a trailing slash.
        """
        self.base_url = base_url

    async def initialize(self) -> None:
        """No-op; this client keeps no persistent state to set up."""
        pass

    async def shutdown(self) -> None:
        """No-op; each request closes its own HTTP client."""
        pass

    async def _post_json(self, route: str, payload: Dict[str, Any]) -> Any:
        """POST ``payload`` as JSON to ``route`` and return the decoded body.

        Raises ``httpx.HTTPStatusError`` (via ``raise_for_status``) on a
        4xx/5xx reply. The body is decoded exactly once here so callers do
        not pay for parsing the same response twice.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}{route}",
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            return response.json()

    async def score_batch(
        self, dataset_id: str, scoring_functions: List[str]
    ) -> ScoreBatchResponse:
        """Score a registered dataset, identified by ``dataset_id``.

        Returns a ``ScoreBatchResponse`` built from the reply, or ``None``
        when the server answers with an empty/falsy JSON body.
        """
        body = await self._post_json(
            "/scoring/score_batch",
            {
                "dataset_id": dataset_id,
                "scoring_functions": scoring_functions,
            },
        )
        if not body:
            return None

        return ScoreBatchResponse(**body)

    async def score(
        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
    ) -> ScoreResponse:
        """Score the given rows with the named scoring functions.

        Returns a ``ScoreResponse`` built from the reply, or ``None`` when
        the server answers with an empty/falsy JSON body.
        """
        body = await self._post_json(
            "/scoring/score",
            {
                "input_rows": input_rows,
                "scoring_functions": scoring_functions,
            },
        )
        if not body:
            return None

        return ScoreResponse(**body)
 | |
| 
 | |
| 
 | |
async def run_main(host: str, port: int):
    """Run a small end-to-end check against the server at ``host:port``.

    Registers the test CSV as ``test-dataset``, lists the datasets, fetches
    one page of rows, scores those rows directly, and finally scores the
    dataset by id. Each result is printed with ``cprint``.
    """
    base_url = f"http://{host}:{port}"
    datasets_client = DatasetsClient(base_url)

    # Register the CSV that lives three directories up from this file.
    csv_path = (
        Path(os.path.abspath(__file__)).parent.parent.parent
        / "providers/tests/datasetio/test_dataset.csv"
    )
    csv_data_url = data_url_from_file(str(csv_path))
    source_url = URL(uri=csv_data_url)
    column_types = {
        "generated_answer": StringType(),
        "expected_answer": StringType(),
        "input_query": StringType(),
    }
    dataset_def = DatasetDefWithProvider(
        identifier="test-dataset",
        provider_id="meta0",
        url=source_url,
        dataset_schema=column_types,
    )
    await datasets_client.register_dataset(dataset_def)

    # Show what the server now knows about.
    known_datasets = await datasets_client.list_datasets()
    cprint(known_datasets, "blue")

    # Pull the first page of rows through the datasetio API.
    rows_client = DatasetIOClient(base_url)
    page = await rows_client.get_rows_paginated(
        dataset_id="test-dataset",
        rows_in_page=4,
        page_token=None,
        filter_condition=None,
    )
    cprint(f"Returned {len(page.rows)} rows \n {page}", "green")

    # Score the fetched rows directly.
    row_scorer = ScoringClient(base_url)
    row_scores = await row_scorer.score(
        input_rows=page.rows,
        scoring_functions=["equality"],
    )
    cprint(f"score response={row_scores}", "blue")

    # Score the whole dataset by id, letting the server read the rows.
    batch_scorer = ScoringClient(base_url)
    batch_scores = await batch_scorer.score_batch(
        dataset_id="test-dataset",
        scoring_functions=["equality"],
    )
    cprint(f"score_batch response={batch_scores}", "cyan")
 | |
| 
 | |
| 
 | |
def main(host: str, port: int):
    """Synchronous wrapper: run ``run_main(host, port)`` to completion."""
    scoring_demo = run_main(host, port)
    asyncio.run(scoring_demo)
 | |
| 
 | |
| 
 | |
# When executed as a script, hand ``main`` to Fire so that ``host`` and
# ``port`` are taken from the command line.
if __name__ == "__main__":
    fire.Fire(main)
 |