mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
address comments
This commit is contained in:
parent
ba0186f2c8
commit
d4887fc746
4 changed files with 8 additions and 8 deletions
|
@@ -18,4 +18,3 @@ class Job(BaseModel):
|
|||
class JobStatus(Enum):
|
||||
completed = "completed"
|
||||
in_progress = "in_progress"
|
||||
not_found = "not_found"
|
||||
|
|
|
@@ -61,7 +61,7 @@ class Eval(Protocol):
|
|||
) -> EvaluateResponse: ...
|
||||
|
||||
@webmethod(route="/eval/job/status", method="GET")
|
||||
async def job_status(self, job_id: str) -> JobStatus: ...
|
||||
async def job_status(self, job_id: str) -> Optional[JobStatus]: ...
|
||||
|
||||
@webmethod(route="/eval/job/cancel", method="POST")
|
||||
async def job_cancel(self, job_id: str) -> None: ...
|
||||
|
|
|
@@ -73,6 +73,8 @@ class MetaReferenceEvalImpl(Eval):
|
|||
scoring_functions=scoring_functions,
|
||||
)
|
||||
|
||||
# TODO: currently needs to wait for generation before returning
|
||||
# need job scheduler queue (ray/celery) w/ jobs api
|
||||
job_id = str(len(self.jobs))
|
||||
self.jobs[job_id] = res
|
||||
return Job(job_id=job_id)
|
||||
|
@@ -99,6 +101,7 @@ class MetaReferenceEvalImpl(Eval):
|
|||
response = await self.inference_api.chat_completion(
|
||||
model=candidate.model,
|
||||
messages=messages,
|
||||
sampling_params=candidate.sampling_params,
|
||||
)
|
||||
generations.append(
|
||||
{"generated_answer": response.completion_message.content}
|
||||
|
@@ -116,18 +119,16 @@ class MetaReferenceEvalImpl(Eval):
|
|||
|
||||
return EvaluateResponse(generations=generations, scores=score_response.results)
|
||||
|
||||
async def job_status(self, job_id: str) -> JobStatus:
|
||||
async def job_status(self, job_id: str) -> Optional[JobStatus]:
|
||||
if job_id in self.jobs:
|
||||
return JobStatus.completed
|
||||
else:
|
||||
return JobStatus.not_found
|
||||
|
||||
async def job_cancel(self, job_id: str) -> None:
|
||||
raise NotImplementedError("Job cancel is not implemented yet")
|
||||
|
||||
async def job_result(self, job_id: str) -> None:
|
||||
async def job_result(self, job_id: str) -> EvaluateResponse:
|
||||
status = await self.job_status(job_id)
|
||||
if status != JobStatus.completed:
|
||||
if not status or status != JobStatus.completed:
|
||||
raise ValueError(f"Job is not completed, Status: {status.value}")
|
||||
|
||||
return self.jobs[job_id]
|
||||
|
|
|
@@ -70,7 +70,7 @@ async def test_eval(eval_settings):
|
|||
assert response.job_id == "0"
|
||||
job_status = await eval_impl.job_status(response.job_id)
|
||||
|
||||
assert job_status.value == "completed"
|
||||
assert job_status and job_status.value == "completed"
|
||||
|
||||
eval_response = await eval_impl.job_result(response.job_id)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue