Xi Yan 2024-11-17 20:18:26 -08:00
parent ff99025875
commit 8fbbea8c43
3 changed files with 31 additions and 11 deletions


@@ -22,6 +22,7 @@ async def get_provider_impl(
deps[Api.datasets],
deps[Api.scoring],
deps[Api.inference],
deps[Api.agents],
)
await impl.initialize()
return impl
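
For context, a minimal sketch of what the full factory might look like once the agents dependency is wired through. Only the lines shown in the hunk above come from the actual change; the parameter names, the deps[Api.datasetio] argument, and the import paths are assumptions.

    from typing import Any, Dict

    from llama_stack.distribution.datatypes import Api  # assumed import path for the Api enum


    async def get_provider_impl(config, deps: Dict[Api, Any]):
        from .eval import MetaReferenceEvalImpl  # assumed module layout

        impl = MetaReferenceEvalImpl(
            config,
            deps[Api.datasetio],  # assumed: not visible in this hunk
            deps[Api.datasets],
            deps[Api.scoring],
            deps[Api.inference],
            deps[Api.agents],  # new in this commit: agents API handle for the eval provider
        )
        await impl.initialize()
        return impl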


@@ -9,6 +9,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval_tasks import EvalTask
@@ -39,12 +40,14 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
datasets_api: Datasets,
scoring_api: Scoring,
inference_api: Inference,
agent_api: Agents,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.scoring_api = scoring_api
self.inference_api = inference_api
self.agent_api = agent_api
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs = {}
@@ -126,18 +129,15 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
self.jobs[job_id] = res
return Job(job_id=job_id)
async def evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: EvalTaskConfig,
) -> EvaluateResponse:
async def _run_agent_generation(
self, task_config: EvalTaskConfig
) -> List[Dict[str, Any]]:
pass
async def _run_model_generation(
self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
) -> List[Dict[str, Any]]:
candidate = task_config.eval_candidate
if candidate.type == "agent":
raise NotImplementedError(
"Evaluation with generation has not been implemented for agents"
)
assert (
candidate.sampling_params.max_tokens is not None
), "SamplingParams.max_tokens must be provided"
@@ -179,6 +179,24 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
else:
raise ValueError("Invalid input row")
return generations
async def evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: EvalTaskConfig,
) -> EvaluateResponse:
candidate = task_config.eval_candidate
if candidate.type == "agent":
raise NotImplementedError(
"Evaluation with generation has not been implemented for agents"
)
if candidate.type == "model":
generations = await self._run_model_generation(input_rows, task_config)
# scoring with generated_answer
score_input_rows = [
input_r | generated_r
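
The hunk above is truncated mid-statement. For readability: input_r | generated_r is the Python 3.9+ dict-union operator, so each scoring input row is the original dataset row merged with the generated fields. Below is a hedged sketch of how the comprehension plausibly continues; the loop variables and the follow-up scoring call are assumptions, not part of the visible diff.

    # Assumed continuation of the truncated comprehension above: merge each
    # input row with its generation via dict union (Python 3.9+).
    score_input_rows = [
        input_r | generated_r
        for input_r, generated_r in zip(input_rows, generations)
    ]
    # The merged rows would then be handed to the scoring API (e.g. a call along
    # the lines of self.scoring_api.score(...)), but that part is not visible here.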


@@ -22,6 +22,7 @@ def available_providers() -> List[ProviderSpec]:
Api.datasets,
Api.scoring,
Api.inference,
Api.agents,
],
),
]
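
The final hunk declares Api.agents as an API dependency of the meta-reference eval provider, so any distribution that enables this provider now has to wire up an agents provider as well. A minimal sketch of what the surrounding registry entry plausibly looks like, assuming an InlineProviderSpec; the provider_type, module, and config_class strings are placeholders rather than the repository's actual values.

    from typing import List

    from llama_stack.providers.datatypes import (  # assumed import path
        Api,
        InlineProviderSpec,
        ProviderSpec,
    )


    def available_providers() -> List[ProviderSpec]:
        return [
            InlineProviderSpec(
                api=Api.eval,
                provider_type="meta-reference",  # placeholder
                pip_packages=[],
                module="llama_stack.providers.inline.eval.meta_reference",  # placeholder
                config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",  # placeholder
                api_dependencies=[
                    Api.datasetio,  # assumed: above the visible hunk context
                    Api.datasets,
                    Api.scoring,
                    Api.inference,
                    Api.agents,  # new in this commit
                ],
            ),
        ]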