test eval works

This commit is contained in:
Xi Yan 2024-11-06 21:40:38 -08:00
parent 413a1b6d8b
commit 3f1ac29d57
3 changed files with 126 additions and 13 deletions

View file

@ -16,6 +16,10 @@ from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.providers.inline.meta_reference.eval.eval import (
DEFAULT_EVAL_TASK_IDENTIFIER,
)
class MemoryRouter(Memory):
"""Routes to an provider based on the memory bank identifier"""
@ -280,7 +284,13 @@ class EvalRouter(Eval):
eval_task_def: EvalTaskDef,
eval_task_config: EvalTaskConfig,
) -> Job:
pass
# NOTE: We need to use DEFAULT_EVAL_TASK_IDENTIFIER to make the router work for all app evals
return await self.routing_table.get_provider_impl(
DEFAULT_EVAL_TASK_IDENTIFIER
).run_eval(
eval_task_def=eval_task_def,
eval_task_config=eval_task_config,
)
@webmethod(route="/eval/evaluate_rows", method="POST")
async def evaluate_rows(
@ -289,13 +299,27 @@ class EvalRouter(Eval):
scoring_functions: List[str],
eval_task_config: EvalTaskConfig, # type: ignore
) -> EvaluateResponse:
pass
# NOTE: This handles the case where no eval benchmark task has been pre-registered;
# DEFAULT_EVAL_TASK_IDENTIFIER is used as the identifier instead.
return await self.routing_table.get_provider_impl(
DEFAULT_EVAL_TASK_IDENTIFIER
).evaluate_rows(
input_rows=input_rows,
scoring_functions=scoring_functions,
eval_task_config=eval_task_config,
)
async def job_status(self, job_id: str) -> Optional[JobStatus]:
    """Fetch the status of an eval job from the default eval provider.

    Args:
        job_id: Identifier of the job; forwarded unchanged to the provider.

    Returns:
        Whatever the provider's ``job_status`` coroutine returns.
    """
    # The provider is looked up with the fixed DEFAULT_EVAL_TASK_IDENTIFIER,
    # not with anything derived from job_id.
    provider = self.routing_table.get_provider_impl(DEFAULT_EVAL_TASK_IDENTIFIER)
    return await provider.job_status(job_id)
async def job_cancel(self, job_id: str) -> None:
    """Ask the default eval provider to cancel an eval job.

    Args:
        job_id: Identifier of the job; forwarded unchanged to the provider.
    """
    # The provider is looked up with the fixed DEFAULT_EVAL_TASK_IDENTIFIER,
    # not with anything derived from job_id.
    provider = self.routing_table.get_provider_impl(DEFAULT_EVAL_TASK_IDENTIFIER)
    # The provider's result is awaited but intentionally not returned.
    await provider.job_cancel(job_id)
async def job_result(self, job_id: str) -> EvaluateResponse:
    """Fetch the result of an eval job from the default eval provider.

    Args:
        job_id: Identifier of the job; forwarded unchanged to the provider.

    Returns:
        Whatever the provider's ``job_result`` coroutine returns.
    """
    # The provider is looked up with the fixed DEFAULT_EVAL_TASK_IDENTIFIER,
    # not with anything derived from job_id.
    provider = self.routing_table.get_provider_impl(DEFAULT_EVAL_TASK_IDENTIFIER)
    return await provider.job_result(job_id)