This commit is contained in:
Xi Yan 2024-10-03 11:36:18 -07:00
parent 5e9301de90
commit 7143ecfc0d

View file

@ -50,6 +50,13 @@ class EvaluateTaskRequestCommon(BaseModel):
sampling_params: SamplingParams = SamplingParams() sampling_params: SamplingParams = SamplingParams()
@json_schema_type
class EvaluateResponse(BaseModel):
"""Scores for evaluation."""
scores = Dict[str, str]
@json_schema_type @json_schema_type
class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon): class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
"""Request to evaluate text generation.""" """Request to evaluate text generation."""
@ -91,30 +98,19 @@ class EvaluationJobCreateResponse(BaseModel):
class Evaluations(Protocol): class Evaluations(Protocol):
@webmethod(route="/evaluate/text_generation/") @webmethod(route="/evaluate")
def create_evaluation_job(self, model: str, dataset: str) -> EvaluationJob: ... async def evaluate(
self, model: str, dataset: str, task: str
# @webmethod(route="/evaluate/text_generation/") ) -> EvaluateResponse: ...
# def evaluate_text_generation(
# self,
# metrics: List[TextGenerationMetric],
# ) -> EvaluationJob: ...
# @webmethod(route="/evaluate/question_answering/")
# def evaluate_question_answering(
# self,
# metrics: List[QuestionAnsweringMetric],
# ) -> EvaluationJob: ...
# @webmethod(route="/evaluate/summarization/")
# def evaluate_summarization(
# self,
# metrics: List[SummarizationMetric],
# ) -> EvaluationJob: ...
@webmethod(route="/evaluate/jobs") @webmethod(route="/evaluate/jobs")
def get_evaluation_jobs(self) -> List[EvaluationJob]: ... def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
@webmethod(route="/evaluate/job/create")
async def create_evaluation_job(
self, model: str, dataset: str, task: str
) -> EvaluationJob: ...
@webmethod(route="/evaluate/job/status") @webmethod(route="/evaluate/job/status")
def get_evaluation_job_status( def get_evaluation_job_status(
self, job_uuid: str self, job_uuid: str