forked from phoenix-oss/llama-stack-mirror
* type system api * datasets api * fix * datasetio api * kill reward scoring * scoring functions + evals * move jobs, fix errors
70 lines
2 KiB
Python
70 lines
2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from typing import Literal, Optional, Protocol, Union
|
|
|
|
from typing_extensions import Annotated
|
|
|
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
from llama_models.schema_utils import json_schema_type, webmethod
|
|
from llama_stack.apis.scoring_functions import * # noqa: F403
|
|
from llama_stack.apis.agents import AgentConfig
|
|
from llama_stack.apis.common.job_types import Job
|
|
from llama_stack.apis.scoring import * # noqa: F403
|
|
|
|
|
|
@json_schema_type
class ModelCandidate(BaseModel):
    """Evaluation candidate backed by a model (the non-agent member of ``EvalCandidate``).

    Chosen from the ``EvalCandidate`` union when its ``type`` field is ``"model"``.
    """

    # Discriminator value used by the EvalCandidate tagged union.
    type: Literal["model"] = "model"
    # Model to evaluate; presumably a registered model identifier — TODO confirm.
    model: str
    # Sampling parameters to use with ``model``; required (no default).
    sampling_params: SamplingParams
    # Optional system message; defaults to None when not supplied.
    system_message: Optional[SystemMessage] = None
|
|
|
|
|
|
@json_schema_type
class AgentCandidate(BaseModel):
    """Evaluation candidate backed by an agent described by ``config``.

    Chosen from the ``EvalCandidate`` union when its ``type`` field is ``"agent"``.
    """

    # Discriminator value used by the EvalCandidate tagged union.
    type: Literal["agent"] = "agent"
    # Agent configuration for this candidate; required (no default).
    config: AgentConfig
|
|
|
|
|
|
# Tagged union of the supported candidate kinds. The shared ``type`` field is
# the discriminator: "model" selects ModelCandidate, "agent" selects
# AgentCandidate.
EvalCandidate = Annotated[
    Union[ModelCandidate, AgentCandidate],
    Field(discriminator="type"),
]
|
|
|
|
|
|
@json_schema_type
class EvaluateResponse(BaseModel):
    """Response returned by ``Eval.evaluate``: candidate generations and their scores."""

    # Generated outputs, one free-form dict per entry (keys are not fixed here).
    # NOTE(review): presumably index-aligned with the request's input_rows — confirm.
    generations: List[Dict[str, Any]]

    # each key in the dict is a scoring function name
    scores: List[Dict[str, ScoringResult]]
|
|
|
|
|
|
class Eval(Protocol):
    """Eval API: score an ``EvalCandidate`` with named scoring functions.

    Two entry points are exposed: ``evaluate_batch`` over a stored dataset
    (returning a ``Job``) and ``evaluate`` over rows passed inline (returning
    an ``EvaluateResponse``). The ``job_*`` routes take a ``job_id`` string.
    """

    @webmethod(route="/eval/evaluate_batch", method="POST")
    async def evaluate_batch(
        self,
        dataset_id: str,
        candidate: EvalCandidate,
        scoring_functions: List[str],
    ) -> Job:
        """Evaluate ``candidate`` on the dataset named by ``dataset_id``.

        :param dataset_id: Identifier of the dataset to evaluate on.
        :param candidate: Model or agent candidate to evaluate.
        :param scoring_functions: Names of the scoring functions to apply.
        :returns: A ``Job``; presumably its id is what the ``job_*`` routes
            expect as ``job_id`` — TODO confirm.
        """

    @webmethod(route="/eval/evaluate", method="POST")
    async def evaluate(
        self,
        input_rows: List[Dict[str, Any]],
        candidate: EvalCandidate,
        scoring_functions: List[str],
    ) -> EvaluateResponse:
        """Evaluate ``candidate`` on rows supplied directly in the request.

        :param input_rows: Rows to evaluate, each a free-form dict.
        :param candidate: Model or agent candidate to evaluate.
        :param scoring_functions: Names of the scoring functions to apply;
            these names key the dicts in ``EvaluateResponse.scores``.
        :returns: The generations together with their scores.
        """

    @webmethod(route="/eval/job/status", method="GET")
    async def job_status(
        self,
        job_id: str,
    ) -> None:
        """Query the status of the job identified by ``job_id``.

        NOTE(review): declared to return None, so no status value is exposed
        through this signature — looks like a placeholder; confirm.
        """

    @webmethod(route="/eval/job/cancel", method="POST")
    async def job_cancel(
        self,
        job_id: str,
    ) -> None:
        """Request cancellation of the job identified by ``job_id``."""

    @webmethod(route="/eval/job/result", method="GET")
    async def job_result(
        self,
        job_id: str,
    ) -> None:
        """Fetch the result of the job identified by ``job_id``.

        NOTE(review): declared to return None rather than a result type —
        looks like a placeholder; confirm the intended response model.
        """
|