From 5cf7779b8fa6777c21886164c2586cc44a501b38 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sat, 15 Mar 2025 17:36:39 -0700
Subject: [PATCH] fix integration

---
 .../inline/eval/meta_reference/eval.py | 76 ++++++++++++++-----
 tests/integration/eval/test_eval.py    | 40 +++++++---
 2 files changed, 87 insertions(+), 29 deletions(-)

diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index ae5b81a09..20c7cca16 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -19,11 +19,7 @@ from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )
-from llama_stack.providers.utils.common.data_schema_validator import (
-    ColumnName,
-    get_valid_schemas,
-    validate_dataset_schema,
-)
+from llama_stack.providers.utils.common.data_schema_validator import ColumnName
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from .....apis.common.job_types import Job
@@ -89,10 +85,14 @@ class MetaReferenceEvalImpl(
         dataset_id = task_def.dataset_id
         scoring_functions = task_def.scoring_functions
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
+        # TODO: validate dataset schema
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
-            limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
+            limit=(
+                -1
+                if benchmark_config.num_examples is None
+                else benchmark_config.num_examples
+            ),
         )
         res = await self.evaluate_rows(
             benchmark_id=benchmark_id,
@@ -118,10 +118,14 @@ class MetaReferenceEvalImpl(
         for i, x in tqdm(enumerate(input_rows)):
             assert ColumnName.chat_completion_input.value in x, "Invalid input row"
             input_messages = json.loads(x[ColumnName.chat_completion_input.value])
-            input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
+            input_messages = [
+                UserMessage(**x) for x in input_messages if x["role"] == "user"
+            ]
 
             # NOTE: only single-turn agent generation is supported. Create a new session for each input row
-            session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
+            session_create_response = await self.agents_api.create_agent_session(
+                agent_id, f"session-{i}"
+            )
             session_id = session_create_response.session_id
 
             turn_request = dict(
@@ -130,7 +134,12 @@ class MetaReferenceEvalImpl(
                 messages=input_messages,
                 stream=True,
             )
-            turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
+            turn_response = [
+                chunk
+                async for chunk in await self.agents_api.create_agent_turn(
+                    **turn_request
+                )
+            ]
             final_event = turn_response[-1].event.payload
 
             # check if there's a memory retrieval step and extract the context
@@ -139,10 +148,14 @@ class MetaReferenceEvalImpl(
                 if step.step_type == StepType.tool_execution.value:
                     for tool_response in step.tool_responses:
                         if tool_response.tool_name == MEMORY_QUERY_TOOL:
-                            memory_rag_context = " ".join(x.text for x in tool_response.content)
+                            memory_rag_context = " ".join(
+                                x.text for x in tool_response.content
+                            )
 
             agent_generation = {}
-            agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
+            agent_generation[ColumnName.generated_answer.value] = (
+                final_event.turn.output_message.content
+            )
             if memory_rag_context:
                 agent_generation[ColumnName.context.value] = memory_rag_context
 
@@ -154,7 +167,9 @@ class MetaReferenceEvalImpl(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
     ) -> List[Dict[str, Any]]:
         candidate = benchmark_config.eval_candidate
-        assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
+        assert (
+            candidate.sampling_params.max_tokens is not None
+        ), "SamplingParams.max_tokens must be provided"
 
         generations = []
         for x in tqdm(input_rows):
@@ -165,21 +180,39 @@ class MetaReferenceEvalImpl(
                     content=input_content,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
             elif ColumnName.chat_completion_input.value in x:
-                chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
-                input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
+                chat_completion_input_json = json.loads(
+                    x[ColumnName.chat_completion_input.value]
+                )
+                input_messages = [
+                    UserMessage(**x)
+                    for x in chat_completion_input_json
+                    if x["role"] == "user"
+                ]
                 messages = []
                 if candidate.system_message:
                     messages.append(candidate.system_message)
-                messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
+                messages += [
+                    SystemMessage(**x)
+                    for x in chat_completion_input_json
+                    if x["role"] == "system"
+                ]
                 messages += input_messages
                 response = await self.inference_api.chat_completion(
                     model_id=candidate.model,
                     messages=messages,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
             else:
                 raise ValueError("Invalid input row")
 
@@ -202,7 +235,8 @@ class MetaReferenceEvalImpl(
 
         # scoring with generated_answer
         score_input_rows = [
-            input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
+            input_r | generated_r
+            for input_r, generated_r in zip(input_rows, generations, strict=False)
         ]
 
         if benchmark_config.scoring_params is not None:
@@ -211,7 +245,9 @@
                 for scoring_fn_id in scoring_functions
             }
         else:
-            scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}
+            scoring_functions_dict = {
+                scoring_fn_id: None for scoring_fn_id in scoring_functions
+            }
 
         score_response = await self.scoring_api.score(
             input_rows=score_input_rows, scoring_functions=scoring_functions_dict
diff --git a/tests/integration/eval/test_eval.py b/tests/integration/eval/test_eval.py
index e25daabbe..ee276e72b 100644
--- a/tests/integration/eval/test_eval.py
+++ b/tests/integration/eval/test_eval.py
@@ -3,11 +3,13 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import os
 import uuid
+from pathlib import Path
 
 import pytest
 
-from ..datasetio.test_datasetio import register_dataset
+from ..datasets.test_datasets import data_url_from_file
 
 # How to run this test:
 #
@@ -16,12 +18,20 @@ from ..datasetio.test_datasetio import register_dataset
 
 @pytest.mark.parametrize("scoring_fn_id", ["basic::equality"])
 def test_evaluate_rows(llama_stack_client, text_model_id, scoring_fn_id):
-    register_dataset(llama_stack_client, for_generation=True, dataset_id="test_dataset_for_eval")
+    dataset = llama_stack_client.datasets.register(
+        purpose="eval/messages-answer",
+        source={
+            "type": "uri",
+            "uri": data_url_from_file(
+                Path(__file__).parent.parent / "datasets" / "test_dataset.csv"
+            ),
+        },
+    )
     response = llama_stack_client.datasets.list()
-    assert any(x.identifier == "test_dataset_for_eval" for x in response)
+    assert any(x.identifier == dataset.identifier for x in response)
 
     rows = llama_stack_client.datasets.iterrows(
-        dataset_id="test_dataset_for_eval",
+        dataset_id=dataset.identifier,
         limit=3,
     )
     assert len(rows.data) == 3
@@ -32,7 +42,7 @@ def test_evaluate_rows(llama_stack_client, text_model_id, scoring_fn_id):
     benchmark_id = str(uuid.uuid4())
     llama_stack_client.benchmarks.register(
         benchmark_id=benchmark_id,
-        dataset_id="test_dataset_for_eval",
+        dataset_id=dataset.identifier,
         scoring_functions=scoring_functions,
     )
     list_benchmarks = llama_stack_client.benchmarks.list()
@@ -59,11 +69,19 @@ def test_evaluate_rows(llama_stack_client, text_model_id, scoring_fn_id):
 
 @pytest.mark.parametrize("scoring_fn_id", ["basic::subset_of"])
 def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id):
-    register_dataset(llama_stack_client, for_generation=True, dataset_id="test_dataset_for_eval_2")
+    dataset = llama_stack_client.datasets.register(
+        purpose="eval/messages-answer",
+        source={
+            "type": "uri",
+            "uri": data_url_from_file(
+                Path(__file__).parent.parent / "datasets" / "test_dataset.csv"
+            ),
+        },
+    )
     benchmark_id = str(uuid.uuid4())
     llama_stack_client.benchmarks.register(
         benchmark_id=benchmark_id,
-        dataset_id="test_dataset_for_eval_2",
+        dataset_id=dataset.identifier,
         scoring_functions=[scoring_fn_id],
     )
 
@@ -80,10 +98,14 @@ def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id):
         },
     )
     assert response.job_id == "0"
-    job_status = llama_stack_client.eval.jobs.status(job_id=response.job_id, benchmark_id=benchmark_id)
+    job_status = llama_stack_client.eval.jobs.status(
+        job_id=response.job_id, benchmark_id=benchmark_id
+    )
     assert job_status and job_status == "completed"
 
-    eval_response = llama_stack_client.eval.jobs.retrieve(job_id=response.job_id, benchmark_id=benchmark_id)
+    eval_response = llama_stack_client.eval.jobs.retrieve(
+        job_id=response.job_id, benchmark_id=benchmark_id
+    )
     assert eval_response is not None
     assert len(eval_response.generations) == 5
     assert scoring_fn_id in eval_response.scores
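
Usage sketch (illustrative, not part of the patch): the updated tests drop the old register_dataset helper and register the bundled CSV inline via llama_stack_client.datasets.register() with a data URL. A minimal standalone version of that flow is below; the LlamaStackClient construction, the base_url, and the import path for data_url_from_file are assumptions for illustration only.

# Sketch only: assumes a running llama-stack server on localhost:8321 and that the
# data_url_from_file helper from tests/integration/datasets/test_datasets.py is importable.
from pathlib import Path

from llama_stack_client import LlamaStackClient

from tests.integration.datasets.test_datasets import data_url_from_file  # assumed importable

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed endpoint

# Register the test CSV as an eval dataset, mirroring the updated tests.
dataset = client.datasets.register(
    purpose="eval/messages-answer",
    source={
        "type": "uri",
        "uri": data_url_from_file(Path("tests/integration/datasets/test_dataset.csv")),
    },
)

# Read back a few rows to confirm the registration worked.
rows = client.datasets.iterrows(dataset_id=dataset.identifier, limit=3)
print(rows.data)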