address comments

This commit is contained in:
Xi Yan 2024-10-24 14:49:02 -07:00
parent ba0186f2c8
commit d4887fc746
4 changed files with 8 additions and 8 deletions

View file

@ -18,4 +18,3 @@ class Job(BaseModel):
class JobStatus(Enum): class JobStatus(Enum):
completed = "completed" completed = "completed"
in_progress = "in_progress" in_progress = "in_progress"
not_found = "not_found"

View file

@ -61,7 +61,7 @@ class Eval(Protocol):
) -> EvaluateResponse: ... ) -> EvaluateResponse: ...
@webmethod(route="/eval/job/status", method="GET") @webmethod(route="/eval/job/status", method="GET")
async def job_status(self, job_id: str) -> JobStatus: ... async def job_status(self, job_id: str) -> Optional[JobStatus]: ...
@webmethod(route="/eval/job/cancel", method="POST") @webmethod(route="/eval/job/cancel", method="POST")
async def job_cancel(self, job_id: str) -> None: ... async def job_cancel(self, job_id: str) -> None: ...

View file

@ -73,6 +73,8 @@ class MetaReferenceEvalImpl(Eval):
scoring_functions=scoring_functions, scoring_functions=scoring_functions,
) )
# TODO: currently needs to wait for generation before returning
# need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs)) job_id = str(len(self.jobs))
self.jobs[job_id] = res self.jobs[job_id] = res
return Job(job_id=job_id) return Job(job_id=job_id)
@ -99,6 +101,7 @@ class MetaReferenceEvalImpl(Eval):
response = await self.inference_api.chat_completion( response = await self.inference_api.chat_completion(
model=candidate.model, model=candidate.model,
messages=messages, messages=messages,
sampling_params=candidate.sampling_params,
) )
generations.append( generations.append(
{"generated_answer": response.completion_message.content} {"generated_answer": response.completion_message.content}
@ -116,18 +119,16 @@ class MetaReferenceEvalImpl(Eval):
return EvaluateResponse(generations=generations, scores=score_response.results) return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, job_id: str) -> JobStatus: async def job_status(self, job_id: str) -> Optional[JobStatus]:
if job_id in self.jobs: if job_id in self.jobs:
return JobStatus.completed return JobStatus.completed
else:
return JobStatus.not_found
async def job_cancel(self, job_id: str) -> None: async def job_cancel(self, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet") raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, job_id: str) -> None: async def job_result(self, job_id: str) -> EvaluateResponse:
status = await self.job_status(job_id) status = await self.job_status(job_id)
if status != JobStatus.completed: if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}") raise ValueError(f"Job is not completed, Status: {status.value}")
return self.jobs[job_id] return self.jobs[job_id]

View file

@ -70,7 +70,7 @@ async def test_eval(eval_settings):
assert response.job_id == "0" assert response.job_id == "0"
job_status = await eval_impl.job_status(response.job_id) job_status = await eval_impl.job_status(response.job_id)
assert job_status.value == "completed" assert job_status and job_status.value == "completed"
eval_response = await eval_impl.job_result(response.job_id) eval_response = await eval_impl.job_result(response.job_id)