mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-18 03:09:49 +00:00
more scoring function for rag
This commit is contained in:
parent
b94ab8d013
commit
9aa4a405ca
8 changed files with 132 additions and 10 deletions
|
|
@ -37,9 +37,15 @@ def data_url_from_file(file_path: str) -> str:
|
|||
|
||||
|
||||
async def register_dataset(
|
||||
datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset"
|
||||
datasets_impl: Datasets,
|
||||
for_generation=False,
|
||||
for_rag=False,
|
||||
dataset_id="test_dataset",
|
||||
):
|
||||
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
|
||||
if for_rag:
|
||||
test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv"
|
||||
else:
|
||||
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
|
||||
test_url = data_url_from_file(str(test_file))
|
||||
|
||||
if for_generation:
|
||||
|
|
@ -48,6 +54,13 @@ async def register_dataset(
|
|||
"input_query": StringType(),
|
||||
"chat_completion_input": ChatCompletionInputType(),
|
||||
}
|
||||
elif for_rag:
|
||||
dataset_schema = {
|
||||
"expected_answer": StringType(),
|
||||
"input_query": StringType(),
|
||||
"generated_answer": StringType(),
|
||||
"context": StringType(),
|
||||
}
|
||||
else:
|
||||
dataset_schema = {
|
||||
"expected_answer": StringType(),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue