From d4887fc746689d028e6b6c8638824bb1a0242b86 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 24 Oct 2024 14:49:02 -0700
Subject: [PATCH] address comments

---
 llama_stack/apis/common/job_types.py            |  1 -
 llama_stack/apis/eval/eval.py                   |  2 +-
 .../providers/impls/meta_reference/eval/eval.py | 13 +++++++------
 llama_stack/providers/tests/eval/test_eval.py   |  2 +-
 4 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/llama_stack/apis/common/job_types.py b/llama_stack/apis/common/job_types.py
index 3161e3e87..ab8ab22dc 100644
--- a/llama_stack/apis/common/job_types.py
+++ b/llama_stack/apis/common/job_types.py
@@ -18,4 +18,3 @@ class Job(BaseModel):
 class JobStatus(Enum):
     completed = "completed"
     in_progress = "in_progress"
-    not_found = "not_found"
diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index cfcada766..51f49da15 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -61,7 +61,7 @@ class Eval(Protocol):
     ) -> EvaluateResponse: ...
 
     @webmethod(route="/eval/job/status", method="GET")
-    async def job_status(self, job_id: str) -> JobStatus: ...
+    async def job_status(self, job_id: str) -> Optional[JobStatus]: ...
 
     @webmethod(route="/eval/job/cancel", method="POST")
     async def job_cancel(self, job_id: str) -> None: ...
diff --git a/llama_stack/providers/impls/meta_reference/eval/eval.py b/llama_stack/providers/impls/meta_reference/eval/eval.py
index abf187bab..e5e2bcdc0 100644
--- a/llama_stack/providers/impls/meta_reference/eval/eval.py
+++ b/llama_stack/providers/impls/meta_reference/eval/eval.py
@@ -73,6 +73,8 @@ class MetaReferenceEvalImpl(Eval):
             scoring_functions=scoring_functions,
         )
 
+        # TODO: currently needs to wait for generation before returning
+        # need job scheduler queue (ray/celery) w/ jobs api
         job_id = str(len(self.jobs))
         self.jobs[job_id] = res
         return Job(job_id=job_id)
@@ -99,6 +101,7 @@ class MetaReferenceEvalImpl(Eval):
             response = await self.inference_api.chat_completion(
                 model=candidate.model,
                 messages=messages,
+                sampling_params=candidate.sampling_params,
             )
             generations.append(
                 {"generated_answer": response.completion_message.content}
@@ -116,18 +119,16 @@ class MetaReferenceEvalImpl(Eval):
 
         return EvaluateResponse(generations=generations, scores=score_response.results)
 
-    async def job_status(self, job_id: str) -> JobStatus:
+    async def job_status(self, job_id: str) -> Optional[JobStatus]:
         if job_id in self.jobs:
             return JobStatus.completed
-        else:
-            return JobStatus.not_found
 
     async def job_cancel(self, job_id: str) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")
 
-    async def job_result(self, job_id: str) -> None:
+    async def job_result(self, job_id: str) -> EvaluateResponse:
         status = await self.job_status(job_id)
-        if status != JobStatus.completed:
-            raise ValueError(f"Job is not completed, Status: {status.value}")
+        if not status or status != JobStatus.completed:
+            raise ValueError(f"Job is not completed, Status: {status}")
 
         return self.jobs[job_id]
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index 099153f03..3a1ca169b 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -70,7 +70,7 @@ async def test_eval(eval_settings):
     assert response.job_id == "0"
 
     job_status = await eval_impl.job_status(response.job_id)
-    assert job_status.value == "completed"
+    assert job_status and job_status.value == "completed"
 
     eval_response = await eval_impl.job_result(response.job_id)
 
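
Usage note (not part of the patch): with JobStatus.not_found removed, job_status now returns Optional[JobStatus], and a None return means "unknown job". A minimal caller-side sketch of the new contract, assuming eval_impl is an already-resolved Eval provider instance; the fetch_eval_result helper is hypothetical, for illustration only:

# Hypothetical helper illustrating the new Optional[JobStatus] contract.
# eval_impl is assumed to be a resolved Eval provider (e.g. the
# meta-reference impl); how it is constructed is outside this patch.
from llama_stack.apis.common.job_types import JobStatus


async def fetch_eval_result(eval_impl, job_id: str):
    status = await eval_impl.job_status(job_id)
    if status is None:
        # Unknown job: previously signalled via JobStatus.not_found,
        # now signalled by a None return value.
        raise ValueError(f"No such job: {job_id}")
    if status != JobStatus.completed:
        # Job exists but generation/scoring has not finished yet.
        return None
    # job_result is now annotated to return EvaluateResponse,
    # which carries .generations and .scores.
    return await eval_impl.job_result(job_id)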