diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index c81f9b33d..8a46a89ad 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -2183,7 +2183,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/JobStatus"
+ "$ref": "#/components/schemas/Job"
}
}
}
@@ -7648,16 +7648,6 @@
"title": "PostTrainingJobArtifactsResponse",
"description": "Artifacts of a finetuning job."
},
- "JobStatus": {
- "type": "string",
- "enum": [
- "completed",
- "in_progress",
- "failed",
- "scheduled"
- ],
- "title": "JobStatus"
- },
"PostTrainingJobStatusResponse": {
"type": "object",
"properties": {
@@ -7665,7 +7655,14 @@
"type": "string"
},
"status": {
- "$ref": "#/components/schemas/JobStatus"
+ "type": "string",
+ "enum": [
+ "completed",
+ "in_progress",
+ "failed",
+ "scheduled"
+ ],
+ "title": "JobStatus"
},
"scheduled_at": {
"type": "string",
@@ -8115,6 +8112,30 @@
"title": "IterrowsResponse",
"description": "A paginated list of rows from a dataset."
},
+ "Job": {
+ "type": "object",
+ "properties": {
+ "job_id": {
+ "type": "string"
+ },
+ "status": {
+ "type": "string",
+ "enum": [
+ "completed",
+ "in_progress",
+ "failed",
+ "scheduled"
+ ],
+ "title": "JobStatus"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "job_id",
+ "status"
+ ],
+ "title": "Job"
+ },
"ListAgentSessionsResponse": {
"type": "object",
"properties": {
@@ -9639,19 +9660,6 @@
],
"title": "RunEvalRequest"
},
- "Job": {
- "type": "object",
- "properties": {
- "job_id": {
- "type": "string"
- }
- },
- "additionalProperties": false,
- "required": [
- "job_id"
- ],
- "title": "Job"
- },
"RunShieldRequest": {
"type": "object",
"properties": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 8ea0e1b9c..0b8f90490 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -1491,7 +1491,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/JobStatus'
+ $ref: '#/components/schemas/Job'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@@ -5277,21 +5277,19 @@ components:
- checkpoints
title: PostTrainingJobArtifactsResponse
description: Artifacts of a finetuning job.
- JobStatus:
- type: string
- enum:
- - completed
- - in_progress
- - failed
- - scheduled
- title: JobStatus
PostTrainingJobStatusResponse:
type: object
properties:
job_uuid:
type: string
status:
- $ref: '#/components/schemas/JobStatus'
+ type: string
+ enum:
+ - completed
+ - in_progress
+ - failed
+ - scheduled
+ title: JobStatus
scheduled_at:
type: string
format: date-time
@@ -5556,6 +5554,24 @@ components:
- data
title: IterrowsResponse
description: A paginated list of rows from a dataset.
+ Job:
+ type: object
+ properties:
+ job_id:
+ type: string
+ status:
+ type: string
+ enum:
+ - completed
+ - in_progress
+ - failed
+ - scheduled
+ title: JobStatus
+ additionalProperties: false
+ required:
+ - job_id
+ - status
+ title: Job
ListAgentSessionsResponse:
type: object
properties:
@@ -6550,15 +6566,6 @@ components:
required:
- benchmark_config
title: RunEvalRequest
- Job:
- type: object
- properties:
- job_id:
- type: string
- additionalProperties: false
- required:
- - job_id
- title: Job
RunShieldRequest:
type: object
properties:
diff --git a/llama_stack/apis/common/job_types.py b/llama_stack/apis/common/job_types.py
index bc070017b..9acecc154 100644
--- a/llama_stack/apis/common/job_types.py
+++ b/llama_stack/apis/common/job_types.py
@@ -10,14 +10,14 @@ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
-@json_schema_type
-class Job(BaseModel):
- job_id: str
-
-
-@json_schema_type
class JobStatus(Enum):
completed = "completed"
in_progress = "in_progress"
failed = "failed"
scheduled = "scheduled"
+
+
+@json_schema_type
+class Job(BaseModel):
+ job_id: str
+ status: JobStatus
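
A minimal sketch (illustrative, not part of the patch) of how the reordered models in llama_stack/apis/common/job_types.py behave after this change: JobStatus must now be declared before Job because Job.status references it, and a Job serializes to the same {job_id, status} shape the regenerated "Job" schema in the spec files describes. This assumes Pydantic v2 (for model_dump) and omits the @json_schema_type decorator to stay self-contained.

from enum import Enum

from pydantic import BaseModel


class JobStatus(Enum):
    completed = "completed"
    in_progress = "in_progress"
    failed = "failed"
    scheduled = "scheduled"


# Declared after JobStatus so the status field can reference the enum.
class Job(BaseModel):
    job_id: str
    status: JobStatus


job = Job(job_id="0", status=JobStatus.completed)
# Matches the "Job" schema added to llama-stack-spec.{html,yaml}:
# {"job_id": "0", "status": "completed"}
print(job.model_dump(mode="json"))
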
diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index d05786321..0e5959c37 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
-from llama_stack.apis.common.job_types import Job, JobStatus
+from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -115,7 +115,7 @@ class Eval(Protocol):
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
- async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
+ async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 2cf38f544..1a98888a2 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -8,19 +8,13 @@ import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack.apis.common.content_types import (
- URL,
InterleavedContent,
InterleavedContentItem,
+ URL,
)
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource
-from llama_stack.apis.eval import (
- BenchmarkConfig,
- Eval,
- EvaluateResponse,
- Job,
- JobStatus,
-)
+from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
@@ -94,7 +88,9 @@ class VectorIORouter(VectorIO):
provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
) -> None:
- logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
+ logger.debug(
+ f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
+ )
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
@@ -112,7 +108,9 @@ class VectorIORouter(VectorIO):
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
- return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
+ return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
+ vector_db_id, chunks, ttl_seconds
+ )
async def query_chunks(
self,
@@ -121,7 +119,9 @@ class VectorIORouter(VectorIO):
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
- return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
+ return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
+ vector_db_id, query, params
+ )
class InferenceRouter(Inference):
@@ -158,7 +158,9 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
- await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
+ await self.routing_table.register_model(
+ model_id, provider_model_id, provider_id, metadata, model_type
+ )
def _construct_metrics(
self,
@@ -212,11 +214,16 @@ class InferenceRouter(Inference):
total_tokens: int,
model: Model,
) -> List[MetricInResponse]:
- metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+ metrics = self._construct_metrics(
+ prompt_tokens, completion_tokens, total_tokens, model
+ )
if self.telemetry:
for metric in metrics:
await self.telemetry.log_event(metric)
- return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
+ return [
+ MetricInResponse(metric=metric.metric, value=metric.value)
+ for metric in metrics
+ ]
async def _count_tokens(
self,
@@ -241,7 +248,9 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
- ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+ ) -> Union[
+ ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+ ]:
logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
@@ -251,12 +260,19 @@ class InferenceRouter(Inference):
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
- raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+ raise ValueError(
+ f"Model '{model_id}' is an embedding model and does not support chat completions"
+ )
if tool_config:
if tool_choice and tool_choice != tool_config.tool_choice:
raise ValueError("tool_choice and tool_config.tool_choice must match")
- if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
- raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
+ if (
+ tool_prompt_format
+ and tool_prompt_format != tool_config.tool_prompt_format
+ ):
+ raise ValueError(
+ "tool_prompt_format and tool_config.tool_prompt_format must match"
+ )
else:
params = {}
if tool_choice:
@@ -274,9 +290,14 @@ class InferenceRouter(Inference):
pass
else:
# verify tool_choice is one of the tools
- tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
+ tool_names = [
+ t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
+ for t in tools
+ ]
if tool_config.tool_choice not in tool_names:
- raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
+ raise ValueError(
+ f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
+ )
params = dict(
model_id=model_id,
@@ -291,17 +312,25 @@ class InferenceRouter(Inference):
tool_config=tool_config,
)
provider = self.routing_table.get_provider_impl(model_id)
- prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+ prompt_tokens = await self._count_tokens(
+ messages, tool_config.tool_prompt_format
+ )
if stream:
async def stream_generator():
completion_text = ""
async for chunk in await provider.chat_completion(**params):
- if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+ if (
+ chunk.event.event_type
+ == ChatCompletionResponseEventType.progress
+ ):
if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text
- if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+ if (
+ chunk.event.event_type
+ == ChatCompletionResponseEventType.complete
+ ):
completion_tokens = await self._count_tokens(
[
CompletionMessage(
@@ -318,7 +347,11 @@ class InferenceRouter(Inference):
total_tokens,
model,
)
- chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
+ chunk.metrics = (
+ metrics
+ if chunk.metrics is None
+ else chunk.metrics + metrics
+ )
yield chunk
return stream_generator()
@@ -335,7 +368,9 @@ class InferenceRouter(Inference):
total_tokens,
model,
)
- response.metrics = metrics if response.metrics is None else response.metrics + metrics
+ response.metrics = (
+ metrics if response.metrics is None else response.metrics + metrics
+ )
return response
async def completion(
@@ -356,7 +391,9 @@ class InferenceRouter(Inference):
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
- raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+ raise ValueError(
+ f"Model '{model_id}' is an embedding model and does not support chat completions"
+ )
provider = self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,
@@ -376,7 +413,11 @@ class InferenceRouter(Inference):
async for chunk in await provider.completion(**params):
if hasattr(chunk, "delta"):
completion_text += chunk.delta
- if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+ if (
+ hasattr(chunk, "stop_reason")
+ and chunk.stop_reason
+ and self.telemetry
+ ):
completion_tokens = await self._count_tokens(completion_text)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
@@ -385,7 +426,11 @@ class InferenceRouter(Inference):
total_tokens,
model,
)
- chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
+ chunk.metrics = (
+ metrics
+ if chunk.metrics is None
+ else chunk.metrics + metrics
+ )
yield chunk
return stream_generator()
@@ -399,7 +444,9 @@ class InferenceRouter(Inference):
total_tokens,
model,
)
- response.metrics = metrics if response.metrics is None else response.metrics + metrics
+ response.metrics = (
+ metrics if response.metrics is None else response.metrics + metrics
+ )
return response
async def embeddings(
@@ -415,7 +462,9 @@ class InferenceRouter(Inference):
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm:
- raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
+ raise ValueError(
+ f"Model '{model_id}' is an LLM model and does not support embeddings"
+ )
return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id,
contents=contents,
@@ -449,7 +498,9 @@ class SafetyRouter(Safety):
params: Optional[Dict[str, Any]] = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
- return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
+ return await self.routing_table.register_shield(
+ shield_id, provider_shield_id, provider_id, params
+ )
async def run_shield(
self,
@@ -546,7 +597,9 @@ class ScoringRouter(Scoring):
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
- score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
+ score_response = await self.routing_table.get_provider_impl(
+ fn_identifier
+ ).score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
@@ -564,11 +617,15 @@ class ScoringRouter(Scoring):
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
- logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
+ logger.debug(
+ f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions"
+ )
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
- score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
+ score_response = await self.routing_table.get_provider_impl(
+ fn_identifier
+ ).score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
@@ -611,7 +668,9 @@ class EvalRouter(Eval):
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
- logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
+ logger.debug(
+ f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows"
+ )
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
@@ -623,9 +682,11 @@ class EvalRouter(Eval):
self,
benchmark_id: str,
job_id: str,
- ) -> Optional[JobStatus]:
+ ) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
- return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
+ return await self.routing_table.get_provider_impl(benchmark_id).job_status(
+ benchmark_id, job_id
+ )
async def job_cancel(
self,
@@ -679,9 +740,9 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
- return await self.routing_table.get_provider_impl("insert_into_memory").insert(
- documents, vector_db_id, chunk_size_in_tokens
- )
+ return await self.routing_table.get_provider_impl(
+ "insert_into_memory"
+ ).insert(documents, vector_db_id, chunk_size_in_tokens)
def __init__(
self,
@@ -714,4 +775,6 @@ class ToolRuntimeRouter(ToolRuntime):
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
- return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
+ return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
+ tool_group_id, mcp_endpoint
+ )
diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index 3630d4c03..84ed8437e 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
from tqdm import tqdm
@@ -21,8 +21,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
-from .....apis.common.job_types import Job
-from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus
+from .....apis.common.job_types import Job, JobStatus
+from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "benchmarks:"
@@ -89,7 +89,11 @@ class MetaReferenceEvalImpl(
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
- limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
+ limit=(
+ -1
+ if benchmark_config.num_examples is None
+ else benchmark_config.num_examples
+ ),
)
res = await self.evaluate_rows(
benchmark_id=benchmark_id,
@@ -102,7 +106,7 @@ class MetaReferenceEvalImpl(
# need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs))
self.jobs[job_id] = res
- return Job(job_id=job_id)
+ return Job(job_id=job_id, status=JobStatus.completed)
async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@@ -115,10 +119,14 @@ class MetaReferenceEvalImpl(
for i, x in tqdm(enumerate(input_rows)):
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
- input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
+ input_messages = [
+ UserMessage(**x) for x in input_messages if x["role"] == "user"
+ ]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
- session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
+ session_create_response = await self.agents_api.create_agent_session(
+ agent_id, f"session-{i}"
+ )
session_id = session_create_response.session_id
turn_request = dict(
@@ -127,7 +135,12 @@ class MetaReferenceEvalImpl(
messages=input_messages,
stream=True,
)
- turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
+ turn_response = [
+ chunk
+ async for chunk in await self.agents_api.create_agent_turn(
+ **turn_request
+ )
+ ]
final_event = turn_response[-1].event.payload
# check if there's a memory retrieval step and extract the context
@@ -136,10 +149,14 @@ class MetaReferenceEvalImpl(
if step.step_type == StepType.tool_execution.value:
for tool_response in step.tool_responses:
if tool_response.tool_name == MEMORY_QUERY_TOOL:
- memory_rag_context = " ".join(x.text for x in tool_response.content)
+ memory_rag_context = " ".join(
+ x.text for x in tool_response.content
+ )
agent_generation = {}
- agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
+ agent_generation[ColumnName.generated_answer.value] = (
+ final_event.turn.output_message.content
+ )
if memory_rag_context:
agent_generation[ColumnName.context.value] = memory_rag_context
@@ -151,7 +168,9 @@ class MetaReferenceEvalImpl(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
) -> List[Dict[str, Any]]:
candidate = benchmark_config.eval_candidate
- assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
+ assert (
+ candidate.sampling_params.max_tokens is not None
+ ), "SamplingParams.max_tokens must be provided"
generations = []
for x in tqdm(input_rows):
@@ -162,21 +181,39 @@ class MetaReferenceEvalImpl(
content=input_content,
sampling_params=candidate.sampling_params,
)
- generations.append({ColumnName.generated_answer.value: response.completion_message.content})
+ generations.append(
+ {
+ ColumnName.generated_answer.value: response.completion_message.content
+ }
+ )
elif ColumnName.chat_completion_input.value in x:
- chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
- input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
+ chat_completion_input_json = json.loads(
+ x[ColumnName.chat_completion_input.value]
+ )
+ input_messages = [
+ UserMessage(**x)
+ for x in chat_completion_input_json
+ if x["role"] == "user"
+ ]
messages = []
if candidate.system_message:
messages.append(candidate.system_message)
- messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
+ messages += [
+ SystemMessage(**x)
+ for x in chat_completion_input_json
+ if x["role"] == "system"
+ ]
messages += input_messages
response = await self.inference_api.chat_completion(
model_id=candidate.model,
messages=messages,
sampling_params=candidate.sampling_params,
)
- generations.append({ColumnName.generated_answer.value: response.completion_message.content})
+ generations.append(
+ {
+ ColumnName.generated_answer.value: response.completion_message.content
+ }
+ )
else:
raise ValueError("Invalid input row")
@@ -199,7 +236,8 @@ class MetaReferenceEvalImpl(
# scoring with generated_answer
score_input_rows = [
- input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
+ input_r | generated_r
+ for input_r, generated_r in zip(input_rows, generations, strict=False)
]
if benchmark_config.scoring_params is not None:
@@ -208,7 +246,9 @@ class MetaReferenceEvalImpl(
for scoring_fn_id in scoring_functions
}
else:
- scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}
+ scoring_functions_dict = {
+ scoring_fn_id: None for scoring_fn_id in scoring_functions
+ }
score_response = await self.scoring_api.score(
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
@@ -216,17 +256,18 @@ class MetaReferenceEvalImpl(
return EvaluateResponse(generations=generations, scores=score_response.results)
- async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
+ async def job_status(self, benchmark_id: str, job_id: str) -> Job:
if job_id in self.jobs:
- return JobStatus.completed
+ return Job(job_id=job_id, status=JobStatus.completed)
- return None
+ raise ValueError(f"Job {job_id} not found")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
- status = await self.job_status(benchmark_id, job_id)
+ job = await self.job_status(benchmark_id, job_id)
+ status = job.status
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")
diff --git a/tests/integration/eval/test_eval.py b/tests/integration/eval/test_eval.py
index c4aa0fa1b..d1c3de519 100644
--- a/tests/integration/eval/test_eval.py
+++ b/tests/integration/eval/test_eval.py
@@ -94,7 +94,7 @@ def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id):
)
assert response.job_id == "0"
job_status = llama_stack_client.eval.jobs.status(job_id=response.job_id, benchmark_id=benchmark_id)
- assert job_status and job_status == "completed"
+ assert job_status and job_status.status == "completed"
eval_response = llama_stack_client.eval.jobs.retrieve(job_id=response.job_id, benchmark_id=benchmark_id)
assert eval_response is not None
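
Since job_status now returns a Job rather than a bare JobStatus (or None), callers read the .status field and handle a ValueError for unknown job IDs. A minimal caller-side sketch, assuming an eval_impl object implementing the updated Eval protocol; the helper name and the polling loop are illustrative, not defined by this change:

import asyncio

from llama_stack.apis.common.job_types import Job, JobStatus


async def wait_for_eval_job(
    eval_impl, benchmark_id: str, job_id: str, interval: float = 5.0
) -> Job:
    """Poll Eval.job_status until the job reaches a terminal state."""
    while True:
        # job_status returns a Job (job_id + status) and raises ValueError
        # for an unknown job_id, instead of returning Optional[JobStatus].
        job = await eval_impl.job_status(benchmark_id=benchmark_id, job_id=job_id)
        if job.status in (JobStatus.completed, JobStatus.failed):
            return job
        await asyncio.sleep(interval)

With the meta-reference provider the job is recorded as completed as soon as run_eval returns, so this loop exits on its first pass; the polling shape only becomes meaningful once a real job scheduler (ray/celery, as the TODO in eval.py notes) backs the API.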