Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-31 16:01:46 +00:00)

refactor

commit 8fbbea8c43 · parent ff99025875 · 3 changed files with 31 additions and 11 deletions
|
```diff
@@ -22,6 +22,7 @@ async def get_provider_impl(
         deps[Api.datasets],
         deps[Api.scoring],
         deps[Api.inference],
+        deps[Api.agents],
     )
     await impl.initialize()
     return impl
```
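To situate this hunk, here is a minimal sketch of what the full factory plausibly looks like. Only the dependency list and the `initialize()` call are taken from the diff; the `config` parameter, the `deps[Api.datasetio]` entry (inferred from the constructor hunk below), the import paths, and the local module layout are assumptions, not something this commit shows.

```python
# Hedged reconstruction, not the verbatim file.
from typing import Any, Dict

from llama_stack.providers.datatypes import Api  # assumed import path


async def get_provider_impl(config: Any, deps: Dict[Api, Any]) -> Any:
    from .eval import MetaReferenceEvalImpl  # assumed local module layout

    impl = MetaReferenceEvalImpl(
        config,
        deps[Api.datasetio],  # inferred: matches self.datasetio_api below
        deps[Api.datasets],
        deps[Api.scoring],
        deps[Api.inference],
        deps[Api.agents],  # the dependency this commit adds
    )
    await impl.initialize()
    return impl
```

Because the dependencies are passed positionally, their order here must match the parameter order of `MetaReferenceEvalImpl.__init__` in the next file.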
```diff
@@ -9,6 +9,7 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from .....apis.common.job_types import Job
 from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.agents import Agents
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval_tasks import EvalTask
```
```diff
@@ -39,12 +40,14 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         datasets_api: Datasets,
         scoring_api: Scoring,
         inference_api: Inference,
+        agent_api: Agents,
     ) -> None:
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
         self.scoring_api = scoring_api
         self.inference_api = inference_api
+        self.agent_api = agent_api
 
         # TODO: assume sync job, will need jobs API for async scheduling
         self.jobs = {}
```
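Since the factory wires these dependencies positionally, the parameter order above is load-bearing. A hedged sketch of constructing the impl directly, for example in a unit test; every stub name here is a hypothetical stand-in, and only the argument order comes from the diff:

```python
# Hypothetical test-style construction; stubs are illustrative only.
from unittest.mock import AsyncMock, MagicMock


async def build_eval_impl():
    impl = MetaReferenceEvalImpl(  # import path depends on the repo layout
        MagicMock(name="config"),
        AsyncMock(name="datasetio_api"),
        AsyncMock(name="datasets_api"),
        AsyncMock(name="scoring_api"),
        AsyncMock(name="inference_api"),
        AsyncMock(name="agent_api"),  # new in this commit
    )
    await impl.initialize()  # mirrors what get_provider_impl does
    return impl
```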
```diff
@@ -126,18 +129,15 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         self.jobs[job_id] = res
         return Job(job_id=job_id)
 
-    async def evaluate_rows(
-        self,
-        task_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        task_config: EvalTaskConfig,
-    ) -> EvaluateResponse:
+    async def _run_agent_generation(
+        self, task_config: EvalTaskConfig
+    ) -> List[Dict[str, Any]]:
+        pass
+
+    async def _run_model_generation(
+        self, input_rows: List[Dict[str, Any]], task_config: EvalTaskConfig
+    ) -> List[Dict[str, Any]]:
         candidate = task_config.eval_candidate
-        if candidate.type == "agent":
-            raise NotImplementedError(
-                "Evaluation with generation has not been implemented for agents"
-            )
         assert (
             candidate.sampling_params.max_tokens is not None
         ), "SamplingParams.max_tokens must be provided"
```
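One thing this hunk leaves visible: `_run_agent_generation` is annotated to return `List[Dict[str, Any]]` but its body is `pass`, so calling it returns `None` and would fail downstream. A stricter placeholder, purely a suggestion and not what the commit does, would raise instead:

```python
# Suggested alternative stub; the commit itself uses `pass`.
from typing import Any, Dict, List


async def _run_agent_generation(
    self, task_config: "EvalTaskConfig"
) -> List[Dict[str, Any]]:
    raise NotImplementedError(
        "Evaluation with generation has not been implemented for agents"
    )
```

As the next hunk shows, `evaluate_rows` already raises `NotImplementedError` for agent candidates before this stub can be reached, so the `pass` body is unreachable in practice.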
```diff
@@ -179,6 +179,24 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             else:
                 raise ValueError("Invalid input row")
 
+        return generations
+
+    async def evaluate_rows(
+        self,
+        task_id: str,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        task_config: EvalTaskConfig,
+    ) -> EvaluateResponse:
+        candidate = task_config.eval_candidate
+        if candidate.type == "agent":
+            raise NotImplementedError(
+                "Evaluation with generation has not been implemented for agents"
+            )
+
+        if candidate.type == "model":
+            generations = await self._run_model_generation(input_rows, task_config)
+
         # scoring with generated_answer
         score_input_rows = [
             input_r | generated_r
```
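The `input_r | generated_r` expression where the hunk cuts off is the Python 3.9+ dict union operator (PEP 584): each scoring input row is the original input row merged with its generated answer, with keys from the right-hand operand winning on conflict. A self-contained illustration with made-up row contents:

```python
# Python 3.9+ dict union; right-hand keys win on conflict.
input_r = {"input_query": "What is 2 + 2?", "expected_answer": "4"}
generated_r = {"generated_answer": "4"}

merged = input_r | generated_r
print(merged)
# {'input_query': 'What is 2 + 2?', 'expected_answer': '4', 'generated_answer': '4'}
```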
```diff
@@ -22,6 +22,7 @@ def available_providers() -> List[ProviderSpec]:
                 Api.datasets,
                 Api.scoring,
                 Api.inference,
+                Api.agents,
             ],
         ),
     ]
```
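For reference, a hedged sketch of the registry entry around this hunk. Only the `api_dependencies` list is taken from the diff; the spec type, field names, and module/config strings are illustrative assumptions about the surrounding code:

```python
# Hedged sketch: fields other than api_dependencies are assumptions.
from typing import List

from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec


def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.eval,
            provider_type="meta-reference",  # assumption
            pip_packages=[],  # assumption
            module="<provider module path>",  # placeholder, not the real path
            config_class="<provider config class>",  # placeholder
            api_dependencies=[
                Api.datasetio,  # inferred from the constructor hunk
                Api.datasets,
                Api.scoring,
                Api.inference,
                Api.agents,  # the dependency this commit adds
            ],
        ),
    ]
```

Declaring `Api.agents` here is what makes the distribution resolver hand an Agents implementation to `get_provider_impl` via `deps[Api.agents]` in the first hunk.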