mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-04 05:52:17 +00:00
jobs eval scoring
This commit is contained in:
parent
3a75799900
commit
7b4f7888f1
5 changed files with 1537 additions and 939 deletions
|
|
@ -4,10 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.job_types import CommonJobFields, JobType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -47,6 +48,27 @@ class ScoreResponse(BaseModel):
|
|||
results: Dict[str, ScoringResult]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoringJob(CommonJobFields):
|
||||
type: Literal["scoring"] = "scoring"
|
||||
|
||||
result_files: List[str] = Field(
|
||||
description="The file ids of the scoring results.",
|
||||
default_factory=list,
|
||||
)
|
||||
result_datasets: List[str] = Field(
|
||||
description="The ids of the datasets containing the scoring results.",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
# how the job is created
|
||||
dataset_id: str = Field(description="The id of the dataset used for scoring.")
|
||||
scoring_fn_ids: List[str] = Field(
|
||||
description="The ids of the scoring functions used.",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
|
||||
class ScoringFunctionStore(Protocol):
|
||||
def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
|
||||
|
||||
|
|
@ -60,7 +82,7 @@ class Scoring(Protocol):
|
|||
self,
|
||||
dataset_id: str,
|
||||
scoring_fn_ids: List[str],
|
||||
) -> ScoreBatchResponse: ...
|
||||
) -> ScoringJob: ...
|
||||
|
||||
@webmethod(route="/scoring/rows", method="POST")
|
||||
async def score(
|
||||
|
|
@ -75,3 +97,36 @@ class Scoring(Protocol):
|
|||
:return: ScoreResponse object containing rows and aggregated results
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring/jobs", method="GET")
|
||||
async def list_scoring_jobs(self) -> List[ScoringJob]:
|
||||
"""List all scoring jobs.
|
||||
|
||||
:return: A list of scoring jobs.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring/job/{job_id}", method="GET")
|
||||
async def get_scoring_job(self, job_id: str) -> Optional[ScoringJob]:
|
||||
"""Get a job by id.
|
||||
|
||||
:param job_id: The id of the job to get.
|
||||
:return: The job.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring/job/{job_id}", method="DELETE")
|
||||
async def delete_scoring_job(self, job_id: str) -> Optional[ScoringJob]:
|
||||
"""Delete a job.
|
||||
|
||||
:param job_id: The id of the job to delete.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring/job/{job_id}/cancel", method="POST")
|
||||
async def cancel_scoring_job(self, job_id: str) -> Optional[ScoringJob]:
|
||||
"""Cancel a job.
|
||||
|
||||
:param job_id: The id of the job to cancel.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue