From 2f140c7ccf14698277a58f557020fd05f7e9ea05 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 21 Mar 2025 13:41:02 -0700 Subject: [PATCH] precommit --- docs/_static/llama-stack-spec.html | 58 ++++--- docs/_static/llama-stack-spec.yaml | 45 +++--- llama_stack/apis/common/job_types.py | 12 +- llama_stack/apis/eval/eval.py | 4 +- llama_stack/distribution/routers/routers.py | 145 +++++++++++++----- .../inline/eval/meta_reference/eval.py | 85 +++++++--- tests/integration/eval/test_eval.py | 2 +- 7 files changed, 235 insertions(+), 116 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index c81f9b33d..8a46a89ad 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -2183,7 +2183,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/JobStatus" + "$ref": "#/components/schemas/Job" } } } @@ -7648,16 +7648,6 @@ "title": "PostTrainingJobArtifactsResponse", "description": "Artifacts of a finetuning job." }, - "JobStatus": { - "type": "string", - "enum": [ - "completed", - "in_progress", - "failed", - "scheduled" - ], - "title": "JobStatus" - }, "PostTrainingJobStatusResponse": { "type": "object", "properties": { @@ -7665,7 +7655,14 @@ "type": "string" }, "status": { - "$ref": "#/components/schemas/JobStatus" + "type": "string", + "enum": [ + "completed", + "in_progress", + "failed", + "scheduled" + ], + "title": "JobStatus" }, "scheduled_at": { "type": "string", @@ -8115,6 +8112,30 @@ "title": "IterrowsResponse", "description": "A paginated list of rows from a dataset." }, + "Job": { + "type": "object", + "properties": { + "job_id": { + "type": "string" + }, + "status": { + "type": "string", + "enum": [ + "completed", + "in_progress", + "failed", + "scheduled" + ], + "title": "JobStatus" + } + }, + "additionalProperties": false, + "required": [ + "job_id", + "status" + ], + "title": "Job" + }, "ListAgentSessionsResponse": { "type": "object", "properties": { @@ -9639,19 +9660,6 @@ ], "title": "RunEvalRequest" }, - "Job": { - "type": "object", - "properties": { - "job_id": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "job_id" - ], - "title": "Job" - }, "RunShieldRequest": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 8ea0e1b9c..0b8f90490 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -1491,7 +1491,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/JobStatus' + $ref: '#/components/schemas/Job' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -5277,21 +5277,19 @@ components: - checkpoints title: PostTrainingJobArtifactsResponse description: Artifacts of a finetuning job. - JobStatus: - type: string - enum: - - completed - - in_progress - - failed - - scheduled - title: JobStatus PostTrainingJobStatusResponse: type: object properties: job_uuid: type: string status: - $ref: '#/components/schemas/JobStatus' + type: string + enum: + - completed + - in_progress + - failed + - scheduled + title: JobStatus scheduled_at: type: string format: date-time @@ -5556,6 +5554,24 @@ components: - data title: IterrowsResponse description: A paginated list of rows from a dataset. 
+ Job: + type: object + properties: + job_id: + type: string + status: + type: string + enum: + - completed + - in_progress + - failed + - scheduled + title: JobStatus + additionalProperties: false + required: + - job_id + - status + title: Job ListAgentSessionsResponse: type: object properties: @@ -6550,15 +6566,6 @@ components: required: - benchmark_config title: RunEvalRequest - Job: - type: object - properties: - job_id: - type: string - additionalProperties: false - required: - - job_id - title: Job RunShieldRequest: type: object properties: diff --git a/llama_stack/apis/common/job_types.py b/llama_stack/apis/common/job_types.py index bc070017b..9acecc154 100644 --- a/llama_stack/apis/common/job_types.py +++ b/llama_stack/apis/common/job_types.py @@ -10,14 +10,14 @@ from pydantic import BaseModel from llama_stack.schema_utils import json_schema_type -@json_schema_type -class Job(BaseModel): - job_id: str - - -@json_schema_type class JobStatus(Enum): completed = "completed" in_progress = "in_progress" failed = "failed" scheduled = "scheduled" + + +@json_schema_type +class Job(BaseModel): + job_id: str + status: JobStatus diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index d05786321..0e5959c37 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -10,7 +10,7 @@ from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_stack.apis.agents import AgentConfig -from llama_stack.apis.common.job_types import Job, JobStatus +from llama_stack.apis.common.job_types import Job from llama_stack.apis.inference import SamplingParams, SystemMessage from llama_stack.apis.scoring import ScoringResult from llama_stack.apis.scoring_functions import ScoringFnParams @@ -115,7 +115,7 @@ class Eval(Protocol): """ @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET") - async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus: + async def job_status(self, benchmark_id: str, job_id: str) -> Job: """Get the status of a job. :param benchmark_id: The ID of the benchmark to run the evaluation on. 
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 2cf38f544..1a98888a2 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -8,19 +8,13 @@ import time from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from llama_stack.apis.common.content_types import ( - URL, InterleavedContent, InterleavedContentItem, + URL, ) from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.datasets import DatasetPurpose, DataSource -from llama_stack.apis.eval import ( - BenchmarkConfig, - Eval, - EvaluateResponse, - Job, - JobStatus, -) +from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job from llama_stack.apis.inference import ( ChatCompletionResponse, ChatCompletionResponseEventType, @@ -94,7 +88,9 @@ class VectorIORouter(VectorIO): provider_id: Optional[str] = None, provider_vector_db_id: Optional[str] = None, ) -> None: - logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") + logger.debug( + f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}" + ) await self.routing_table.register_vector_db( vector_db_id, embedding_model, @@ -112,7 +108,9 @@ class VectorIORouter(VectorIO): logger.debug( f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", ) - return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) + return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks( + vector_db_id, chunks, ttl_seconds + ) async def query_chunks( self, @@ -121,7 +119,9 @@ class VectorIORouter(VectorIO): params: Optional[Dict[str, Any]] = None, ) -> QueryChunksResponse: logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") - return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) + return await self.routing_table.get_provider_impl(vector_db_id).query_chunks( + vector_db_id, query, params + ) class InferenceRouter(Inference): @@ -158,7 +158,9 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", ) - await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) + await self.routing_table.register_model( + model_id, provider_model_id, provider_id, metadata, model_type + ) def _construct_metrics( self, @@ -212,11 +214,16 @@ class InferenceRouter(Inference): total_tokens: int, model: Model, ) -> List[MetricInResponse]: - metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) + metrics = self._construct_metrics( + prompt_tokens, completion_tokens, total_tokens, model + ) if self.telemetry: for metric in metrics: await self.telemetry.log_event(metric) - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] + return [ + MetricInResponse(metric=metric.metric, value=metric.value) + for metric in metrics + ] async def _count_tokens( self, @@ -241,7 +248,9 @@ class InferenceRouter(Inference): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, - ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: 
+ ) -> Union[ + ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] + ]: logger.debug( f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", ) @@ -251,12 +260,19 @@ class InferenceRouter(Inference): if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: - raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") + raise ValueError( + f"Model '{model_id}' is an embedding model and does not support chat completions" + ) if tool_config: if tool_choice and tool_choice != tool_config.tool_choice: raise ValueError("tool_choice and tool_config.tool_choice must match") - if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format: - raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match") + if ( + tool_prompt_format + and tool_prompt_format != tool_config.tool_prompt_format + ): + raise ValueError( + "tool_prompt_format and tool_config.tool_prompt_format must match" + ) else: params = {} if tool_choice: @@ -274,9 +290,14 @@ class InferenceRouter(Inference): pass else: # verify tool_choice is one of the tools - tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] + tool_names = [ + t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value + for t in tools + ] if tool_config.tool_choice not in tool_names: - raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") + raise ValueError( + f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}" + ) params = dict( model_id=model_id, @@ -291,17 +312,25 @@ class InferenceRouter(Inference): tool_config=tool_config, ) provider = self.routing_table.get_provider_impl(model_id) - prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) + prompt_tokens = await self._count_tokens( + messages, tool_config.tool_prompt_format + ) if stream: async def stream_generator(): completion_text = "" async for chunk in await provider.chat_completion(**params): - if chunk.event.event_type == ChatCompletionResponseEventType.progress: + if ( + chunk.event.event_type + == ChatCompletionResponseEventType.progress + ): if chunk.event.delta.type == "text": completion_text += chunk.event.delta.text - if chunk.event.event_type == ChatCompletionResponseEventType.complete: + if ( + chunk.event.event_type + == ChatCompletionResponseEventType.complete + ): completion_tokens = await self._count_tokens( [ CompletionMessage( @@ -318,7 +347,11 @@ class InferenceRouter(Inference): total_tokens, model, ) - chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + chunk.metrics = ( + metrics + if chunk.metrics is None + else chunk.metrics + metrics + ) yield chunk return stream_generator() @@ -335,7 +368,9 @@ class InferenceRouter(Inference): total_tokens, model, ) - response.metrics = metrics if response.metrics is None else response.metrics + metrics + response.metrics = ( + metrics if response.metrics is None else response.metrics + metrics + ) return response async def completion( @@ -356,7 +391,9 @@ class InferenceRouter(Inference): if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: - raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") + raise ValueError( + f"Model '{model_id}' is an embedding 
model and does not support chat completions" + ) provider = self.routing_table.get_provider_impl(model_id) params = dict( model_id=model_id, @@ -376,7 +413,11 @@ class InferenceRouter(Inference): async for chunk in await provider.completion(**params): if hasattr(chunk, "delta"): completion_text += chunk.delta - if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: + if ( + hasattr(chunk, "stop_reason") + and chunk.stop_reason + and self.telemetry + ): completion_tokens = await self._count_tokens(completion_text) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) metrics = await self._compute_and_log_token_usage( @@ -385,7 +426,11 @@ class InferenceRouter(Inference): total_tokens, model, ) - chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + chunk.metrics = ( + metrics + if chunk.metrics is None + else chunk.metrics + metrics + ) yield chunk return stream_generator() @@ -399,7 +444,9 @@ class InferenceRouter(Inference): total_tokens, model, ) - response.metrics = metrics if response.metrics is None else response.metrics + metrics + response.metrics = ( + metrics if response.metrics is None else response.metrics + metrics + ) return response async def embeddings( @@ -415,7 +462,9 @@ class InferenceRouter(Inference): if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.llm: - raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings") + raise ValueError( + f"Model '{model_id}' is an LLM model and does not support embeddings" + ) return await self.routing_table.get_provider_impl(model_id).embeddings( model_id=model_id, contents=contents, @@ -449,7 +498,9 @@ class SafetyRouter(Safety): params: Optional[Dict[str, Any]] = None, ) -> Shield: logger.debug(f"SafetyRouter.register_shield: {shield_id}") - return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) + return await self.routing_table.register_shield( + shield_id, provider_shield_id, provider_id, params + ) async def run_shield( self, @@ -546,7 +597,9 @@ class ScoringRouter(Scoring): logger.debug(f"ScoringRouter.score_batch: {dataset_id}") res = {} for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( + score_response = await self.routing_table.get_provider_impl( + fn_identifier + ).score_batch( dataset_id=dataset_id, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) @@ -564,11 +617,15 @@ class ScoringRouter(Scoring): input_rows: List[Dict[str, Any]], scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, ) -> ScoreResponse: - logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") + logger.debug( + f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions" + ) res = {} # look up and map each scoring function to its provider impl for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score( + score_response = await self.routing_table.get_provider_impl( + fn_identifier + ).score( input_rows=input_rows, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) @@ -611,7 +668,9 @@ class EvalRouter(Eval): scoring_functions: List[str], benchmark_config: BenchmarkConfig, ) -> EvaluateResponse: - logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") + logger.debug( + 
f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows" + ) return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( benchmark_id=benchmark_id, input_rows=input_rows, @@ -623,9 +682,11 @@ class EvalRouter(Eval): self, benchmark_id: str, job_id: str, - ) -> Optional[JobStatus]: + ) -> Job: logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) + return await self.routing_table.get_provider_impl(benchmark_id).job_status( + benchmark_id, job_id + ) async def job_cancel( self, @@ -679,9 +740,9 @@ class ToolRuntimeRouter(ToolRuntime): logger.debug( f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" ) - return await self.routing_table.get_provider_impl("insert_into_memory").insert( - documents, vector_db_id, chunk_size_in_tokens - ) + return await self.routing_table.get_provider_impl( + "insert_into_memory" + ).insert(documents, vector_db_id, chunk_size_in_tokens) def __init__( self, @@ -714,4 +775,6 @@ class ToolRuntimeRouter(ToolRuntime): self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") - return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) + return await self.routing_table.get_provider_impl(tool_group_id).list_tools( + tool_group_id, mcp_endpoint + ) diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index 3630d4c03..84ed8437e 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
import json -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from tqdm import tqdm @@ -21,8 +21,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( from llama_stack.providers.utils.common.data_schema_validator import ColumnName from llama_stack.providers.utils.kvstore import kvstore_impl -from .....apis.common.job_types import Job -from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus +from .....apis.common.job_types import Job, JobStatus +from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse from .config import MetaReferenceEvalConfig EVAL_TASKS_PREFIX = "benchmarks:" @@ -89,7 +89,11 @@ class MetaReferenceEvalImpl( all_rows = await self.datasetio_api.iterrows( dataset_id=dataset_id, - limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples), + limit=( + -1 + if benchmark_config.num_examples is None + else benchmark_config.num_examples + ), ) res = await self.evaluate_rows( benchmark_id=benchmark_id, @@ -102,7 +106,7 @@ class MetaReferenceEvalImpl( # need job scheduler queue (ray/celery) w/ jobs api job_id = str(len(self.jobs)) self.jobs[job_id] = res - return Job(job_id=job_id) + return Job(job_id=job_id, status=JobStatus.completed) async def _run_agent_generation( self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig @@ -115,10 +119,14 @@ class MetaReferenceEvalImpl( for i, x in tqdm(enumerate(input_rows)): assert ColumnName.chat_completion_input.value in x, "Invalid input row" input_messages = json.loads(x[ColumnName.chat_completion_input.value]) - input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"] + input_messages = [ + UserMessage(**x) for x in input_messages if x["role"] == "user" + ] # NOTE: only single-turn agent generation is supported. 
Create a new session for each input row - session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}") + session_create_response = await self.agents_api.create_agent_session( + agent_id, f"session-{i}" + ) session_id = session_create_response.session_id turn_request = dict( @@ -127,7 +135,12 @@ class MetaReferenceEvalImpl( messages=input_messages, stream=True, ) - turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)] + turn_response = [ + chunk + async for chunk in await self.agents_api.create_agent_turn( + **turn_request + ) + ] final_event = turn_response[-1].event.payload # check if there's a memory retrieval step and extract the context @@ -136,10 +149,14 @@ class MetaReferenceEvalImpl( if step.step_type == StepType.tool_execution.value: for tool_response in step.tool_responses: if tool_response.tool_name == MEMORY_QUERY_TOOL: - memory_rag_context = " ".join(x.text for x in tool_response.content) + memory_rag_context = " ".join( + x.text for x in tool_response.content + ) agent_generation = {} - agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content + agent_generation[ColumnName.generated_answer.value] = ( + final_event.turn.output_message.content + ) if memory_rag_context: agent_generation[ColumnName.context.value] = memory_rag_context @@ -151,7 +168,9 @@ class MetaReferenceEvalImpl( self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig ) -> List[Dict[str, Any]]: candidate = benchmark_config.eval_candidate - assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided" + assert ( + candidate.sampling_params.max_tokens is not None + ), "SamplingParams.max_tokens must be provided" generations = [] for x in tqdm(input_rows): @@ -162,21 +181,39 @@ class MetaReferenceEvalImpl( content=input_content, sampling_params=candidate.sampling_params, ) - generations.append({ColumnName.generated_answer.value: response.completion_message.content}) + generations.append( + { + ColumnName.generated_answer.value: response.completion_message.content + } + ) elif ColumnName.chat_completion_input.value in x: - chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value]) - input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"] + chat_completion_input_json = json.loads( + x[ColumnName.chat_completion_input.value] + ) + input_messages = [ + UserMessage(**x) + for x in chat_completion_input_json + if x["role"] == "user" + ] messages = [] if candidate.system_message: messages.append(candidate.system_message) - messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"] + messages += [ + SystemMessage(**x) + for x in chat_completion_input_json + if x["role"] == "system" + ] messages += input_messages response = await self.inference_api.chat_completion( model_id=candidate.model, messages=messages, sampling_params=candidate.sampling_params, ) - generations.append({ColumnName.generated_answer.value: response.completion_message.content}) + generations.append( + { + ColumnName.generated_answer.value: response.completion_message.content + } + ) else: raise ValueError("Invalid input row") @@ -199,7 +236,8 @@ class MetaReferenceEvalImpl( # scoring with generated_answer score_input_rows = [ - input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False) + input_r | generated_r + for input_r, generated_r in zip(input_rows, 
generations, strict=False) ] if benchmark_config.scoring_params is not None: @@ -208,7 +246,9 @@ class MetaReferenceEvalImpl( for scoring_fn_id in scoring_functions } else: - scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions} + scoring_functions_dict = { + scoring_fn_id: None for scoring_fn_id in scoring_functions + } score_response = await self.scoring_api.score( input_rows=score_input_rows, scoring_functions=scoring_functions_dict @@ -216,17 +256,18 @@ class MetaReferenceEvalImpl( return EvaluateResponse(generations=generations, scores=score_response.results) - async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: + async def job_status(self, benchmark_id: str, job_id: str) -> Job: if job_id in self.jobs: - return JobStatus.completed + return Job(job_id=job_id, status=JobStatus.completed) - return None + raise ValueError(f"Job {job_id} not found") async def job_cancel(self, benchmark_id: str, job_id: str) -> None: raise NotImplementedError("Job cancel is not implemented yet") async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: - status = await self.job_status(benchmark_id, job_id) + job = await self.job_status(benchmark_id, job_id) + status = job.status if not status or status != JobStatus.completed: raise ValueError(f"Job is not completed, Status: {status.value}") diff --git a/tests/integration/eval/test_eval.py b/tests/integration/eval/test_eval.py index c4aa0fa1b..d1c3de519 100644 --- a/tests/integration/eval/test_eval.py +++ b/tests/integration/eval/test_eval.py @@ -94,7 +94,7 @@ def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id): ) assert response.job_id == "0" job_status = llama_stack_client.eval.jobs.status(job_id=response.job_id, benchmark_id=benchmark_id) - assert job_status and job_status == "completed" + assert job_status and job_status.status == "completed" eval_response = llama_stack_client.eval.jobs.retrieve(job_id=response.job_id, benchmark_id=benchmark_id) assert eval_response is not None
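
Note (not part of the patch): a minimal usage sketch of how a caller consumes the eval jobs API after this change, mirroring the updated integration test. `client`, `benchmark_id`, and `job_id` are assumed to come from an existing llama_stack_client session and a previously started eval run; the `result.scores` access follows the server-side EvaluateResponse shape.

    # Minimal sketch (assumes a configured llama_stack_client `client`, plus
    # `benchmark_id` and `job_id` obtained from an eval run started earlier).
    job = client.eval.jobs.status(job_id=job_id, benchmark_id=benchmark_id)

    # With this patch, job_status returns a Job (job_id + status) rather than a
    # bare JobStatus, so the status is read off the returned object.
    if job.status == "completed":
        result = client.eval.jobs.retrieve(job_id=job_id, benchmark_id=benchmark_id)
        print(result.scores)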