precommit

Xi Yan 2025-03-21 13:41:02 -07:00
parent d6887f46c6
commit 2f140c7ccf
7 changed files with 235 additions and 116 deletions


@@ -2183,7 +2183,7 @@
             "content": {
               "application/json": {
                 "schema": {
-                  "$ref": "#/components/schemas/JobStatus"
+                  "$ref": "#/components/schemas/Job"
                 }
               }
             }
@@ -7648,16 +7648,6 @@
       "title": "PostTrainingJobArtifactsResponse",
       "description": "Artifacts of a finetuning job."
     },
-    "JobStatus": {
-      "type": "string",
-      "enum": [
-        "completed",
-        "in_progress",
-        "failed",
-        "scheduled"
-      ],
-      "title": "JobStatus"
-    },
     "PostTrainingJobStatusResponse": {
       "type": "object",
       "properties": {
@@ -7665,7 +7655,14 @@
           "type": "string"
         },
         "status": {
-          "$ref": "#/components/schemas/JobStatus"
+          "type": "string",
+          "enum": [
+            "completed",
+            "in_progress",
+            "failed",
+            "scheduled"
+          ],
+          "title": "JobStatus"
         },
         "scheduled_at": {
           "type": "string",
@@ -8115,6 +8112,30 @@
       "title": "IterrowsResponse",
       "description": "A paginated list of rows from a dataset."
     },
+    "Job": {
+      "type": "object",
+      "properties": {
+        "job_id": {
+          "type": "string"
+        },
+        "status": {
+          "type": "string",
+          "enum": [
+            "completed",
+            "in_progress",
+            "failed",
+            "scheduled"
+          ],
+          "title": "JobStatus"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "job_id",
+        "status"
+      ],
+      "title": "Job"
+    },
     "ListAgentSessionsResponse": {
       "type": "object",
       "properties": {
@@ -9639,19 +9660,6 @@
       ],
       "title": "RunEvalRequest"
     },
-    "Job": {
-      "type": "object",
-      "properties": {
-        "job_id": {
-          "type": "string"
-        }
-      },
-      "additionalProperties": false,
-      "required": [
-        "job_id"
-      ],
-      "title": "Job"
-    },
     "RunShieldRequest": {
       "type": "object",
       "properties": {


@@ -1491,7 +1491,7 @@ paths:
           content:
             application/json:
               schema:
-                $ref: '#/components/schemas/JobStatus'
+                $ref: '#/components/schemas/Job'
         '400':
           $ref: '#/components/responses/BadRequest400'
         '429':
@@ -5277,21 +5277,19 @@ components:
       - checkpoints
       title: PostTrainingJobArtifactsResponse
       description: Artifacts of a finetuning job.
-    JobStatus:
-      type: string
-      enum:
-      - completed
-      - in_progress
-      - failed
-      - scheduled
-      title: JobStatus
    PostTrainingJobStatusResponse:
       type: object
       properties:
         job_uuid:
           type: string
         status:
-          $ref: '#/components/schemas/JobStatus'
+          type: string
+          enum:
+          - completed
+          - in_progress
+          - failed
+          - scheduled
+          title: JobStatus
         scheduled_at:
           type: string
           format: date-time
@@ -5556,6 +5554,24 @@ components:
       - data
       title: IterrowsResponse
       description: A paginated list of rows from a dataset.
+    Job:
+      type: object
+      properties:
+        job_id:
+          type: string
+        status:
+          type: string
+          enum:
+          - completed
+          - in_progress
+          - failed
+          - scheduled
+          title: JobStatus
+      additionalProperties: false
+      required:
+      - job_id
+      - status
+      title: Job
     ListAgentSessionsResponse:
       type: object
       properties:
@@ -6550,15 +6566,6 @@ components:
       required:
       - benchmark_config
       title: RunEvalRequest
-    Job:
-      type: object
-      properties:
-        job_id:
-          type: string
-      additionalProperties: false
-      required:
-      - job_id
-      title: Job
     RunShieldRequest:
       type: object
       properties:


@@ -10,14 +10,14 @@ from pydantic import BaseModel
 from llama_stack.schema_utils import json_schema_type
 
 
-@json_schema_type
-class Job(BaseModel):
-    job_id: str
-
-
 @json_schema_type
 class JobStatus(Enum):
     completed = "completed"
     in_progress = "in_progress"
     failed = "failed"
     scheduled = "scheduled"
+
+
+@json_schema_type
+class Job(BaseModel):
+    job_id: str
+    status: JobStatus
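
After this change a Job carries its own status, instead of the status being returned as a bare JobStatus enum. A minimal usage sketch of the reshaped model (the job_id value is made up for illustration, and serialization assumes Pydantic v2):

    from llama_stack.apis.common.job_types import Job, JobStatus

    # Construct a job record the way the eval provider now does.
    job = Job(job_id="0", status=JobStatus.completed)

    # Callers read completion state off the job object rather than comparing the
    # return value itself.
    assert job.status == JobStatus.completed
    print(job.model_dump())  # {'job_id': '0', 'status': <JobStatus.completed: 'completed'>}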


@@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
 from llama_stack.apis.agents import AgentConfig
-from llama_stack.apis.common.job_types import Job, JobStatus
+from llama_stack.apis.common.job_types import Job
 from llama_stack.apis.inference import SamplingParams, SystemMessage
 from llama_stack.apis.scoring import ScoringResult
 from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -115,7 +115,7 @@ class Eval(Protocol):
         """
 
     @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
-    async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
+    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
        """Get the status of a job.

        :param benchmark_id: The ID of the benchmark to run the evaluation on.
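
With job_status returning a full Job, callers check the status field before fetching results. A hedged sketch against the updated Eval protocol (wait_for_result is a hypothetical helper, not part of this commit; eval_api stands for any Eval implementation, such as the router):

    from llama_stack.apis.common.job_types import JobStatus
    from llama_stack.apis.eval import Eval

    async def wait_for_result(eval_api: Eval, benchmark_id: str, job_id: str):
        # Hypothetical helper: poll once and return results only for completed jobs.
        job = await eval_api.job_status(benchmark_id=benchmark_id, job_id=job_id)
        if job.status != JobStatus.completed:
            return None
        return await eval_api.job_result(benchmark_id=benchmark_id, job_id=job_id)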


@@ -8,19 +8,13 @@ import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from llama_stack.apis.common.content_types import (
-    URL,
     InterleavedContent,
     InterleavedContentItem,
+    URL,
 )
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
-from llama_stack.apis.eval import (
-    BenchmarkConfig,
-    Eval,
-    EvaluateResponse,
-    Job,
-    JobStatus,
-)
+from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEventType,
@@ -94,7 +88,9 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
+        logger.debug(
+            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
+        )
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -112,7 +108,9 @@ class VectorIORouter(VectorIO):
         logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
-        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
+        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
+            vector_db_id, chunks, ttl_seconds
+        )
 
     async def query_chunks(
         self,
@@ -121,7 +119,9 @@ class VectorIORouter(VectorIO):
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
+        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
+            vector_db_id, query, params
+        )
 
 
 class InferenceRouter(Inference):
@@ -158,7 +158,9 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
+        await self.routing_table.register_model(
+            model_id, provider_model_id, provider_id, metadata, model_type
+        )
 
     def _construct_metrics(
         self,
@@ -212,11 +214,16 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> List[MetricInResponse]:
-        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        metrics = self._construct_metrics(
+            prompt_tokens, completion_tokens, total_tokens, model
+        )
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
-        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
+        return [
+            MetricInResponse(metric=metric.metric, value=metric.value)
+            for metric in metrics
+        ]
 
     async def _count_tokens(
         self,
@@ -241,7 +248,9 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]:
         logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
@@ -251,12 +260,19 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+            raise ValueError(
+                f"Model '{model_id}' is an embedding model and does not support chat completions"
+            )
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
-                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
+            if (
+                tool_prompt_format
+                and tool_prompt_format != tool_config.tool_prompt_format
+            ):
+                raise ValueError(
+                    "tool_prompt_format and tool_config.tool_prompt_format must match"
+                )
         else:
             params = {}
             if tool_choice:
@@ -274,9 +290,14 @@ class InferenceRouter(Inference):
                 pass
             else:
                 # verify tool_choice is one of the tools
-                tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
+                tool_names = [
+                    t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
+                    for t in tools
+                ]
                 if tool_config.tool_choice not in tool_names:
-                    raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
+                    raise ValueError(
+                        f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
+                    )
 
         params = dict(
             model_id=model_id,
@@ -291,17 +312,25 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+        prompt_tokens = await self._count_tokens(
+            messages, tool_config.tool_prompt_format
+        )
 
         if stream:
 
             async def stream_generator():
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
-                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+                    if (
+                        chunk.event.event_type
+                        == ChatCompletionResponseEventType.progress
+                    ):
                         if chunk.event.delta.type == "text":
                             completion_text += chunk.event.delta.text
-                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+                    if (
+                        chunk.event.event_type
+                        == ChatCompletionResponseEventType.complete
+                    ):
                         completion_tokens = await self._count_tokens(
                             [
                                 CompletionMessage(
@@ -318,7 +347,11 @@ class InferenceRouter(Inference):
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
+                        chunk.metrics = (
+                            metrics
+                            if chunk.metrics is None
+                            else chunk.metrics + metrics
+                        )
                     yield chunk
 
             return stream_generator()
@@ -335,7 +368,9 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = metrics if response.metrics is None else response.metrics + metrics
+        response.metrics = (
+            metrics if response.metrics is None else response.metrics + metrics
+        )
         return response
 
     async def completion(
@@ -356,7 +391,9 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+            raise ValueError(
+                f"Model '{model_id}' is an embedding model and does not support chat completions"
+            )
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -376,7 +413,11 @@ class InferenceRouter(Inference):
                 async for chunk in await provider.completion(**params):
                     if hasattr(chunk, "delta"):
                         completion_text += chunk.delta
-                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                    if (
+                        hasattr(chunk, "stop_reason")
+                        and chunk.stop_reason
+                        and self.telemetry
+                    ):
                         completion_tokens = await self._count_tokens(completion_text)
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                         metrics = await self._compute_and_log_token_usage(
@@ -385,7 +426,11 @@ class InferenceRouter(Inference):
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
+                        chunk.metrics = (
+                            metrics
+                            if chunk.metrics is None
+                            else chunk.metrics + metrics
+                        )
                     yield chunk
 
             return stream_generator()
@@ -399,7 +444,9 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = metrics if response.metrics is None else response.metrics + metrics
+        response.metrics = (
+            metrics if response.metrics is None else response.metrics + metrics
+        )
         return response
 
     async def embeddings(
@@ -415,7 +462,9 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.llm:
-            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
+            raise ValueError(
+                f"Model '{model_id}' is an LLM model and does not support embeddings"
+            )
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,
@@ -449,7 +498,9 @@ class SafetyRouter(Safety):
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
         logger.debug(f"SafetyRouter.register_shield: {shield_id}")
-        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
+        return await self.routing_table.register_shield(
+            shield_id, provider_shield_id, provider_id, params
+        )
 
     async def run_shield(
         self,
@@ -546,7 +597,9 @@ class ScoringRouter(Scoring):
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
+            score_response = await self.routing_table.get_provider_impl(
+                fn_identifier
+            ).score_batch(
                 dataset_id=dataset_id,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -564,11 +617,15 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
+        logger.debug(
+            f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions"
+        )
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
+            score_response = await self.routing_table.get_provider_impl(
+                fn_identifier
+            ).score(
                 input_rows=input_rows,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -611,7 +668,9 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
-        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
+        logger.debug(
+            f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows"
+        )
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -623,9 +682,11 @@ class EvalRouter(Eval):
         self,
         benchmark_id: str,
         job_id: str,
-    ) -> Optional[JobStatus]:
+    ) -> Job:
         logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
+        return await self.routing_table.get_provider_impl(benchmark_id).job_status(
+            benchmark_id, job_id
+        )
 
     async def job_cancel(
         self,
@@ -679,9 +740,9 @@ class ToolRuntimeRouter(ToolRuntime):
             logger.debug(
                 f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
            )
-            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
-                documents, vector_db_id, chunk_size_in_tokens
-            )
+            return await self.routing_table.get_provider_impl(
+                "insert_into_memory"
+            ).insert(documents, vector_db_id, chunk_size_in_tokens)
 
     def __init__(
         self,
@@ -714,4 +775,6 @@ class ToolRuntimeRouter(ToolRuntime):
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
-        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
+            tool_group_id, mcp_endpoint
+        )


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
 from tqdm import tqdm
@@ -21,8 +21,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
 from llama_stack.providers.utils.common.data_schema_validator import ColumnName
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
-from .....apis.common.job_types import Job
-from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus
+from .....apis.common.job_types import Job, JobStatus
+from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
 from .config import MetaReferenceEvalConfig
 
 EVAL_TASKS_PREFIX = "benchmarks:"
@@ -89,7 +89,11 @@ class MetaReferenceEvalImpl(
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
-            limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
+            limit=(
+                -1
+                if benchmark_config.num_examples is None
+                else benchmark_config.num_examples
+            ),
         )
         res = await self.evaluate_rows(
             benchmark_id=benchmark_id,
@@ -102,7 +106,7 @@ class MetaReferenceEvalImpl(
         # need job scheduler queue (ray/celery) w/ jobs api
         job_id = str(len(self.jobs))
         self.jobs[job_id] = res
-        return Job(job_id=job_id)
+        return Job(job_id=job_id, status=JobStatus.completed)
 
     async def _run_agent_generation(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@@ -115,10 +119,14 @@ class MetaReferenceEvalImpl(
         for i, x in tqdm(enumerate(input_rows)):
             assert ColumnName.chat_completion_input.value in x, "Invalid input row"
             input_messages = json.loads(x[ColumnName.chat_completion_input.value])
-            input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
+            input_messages = [
+                UserMessage(**x) for x in input_messages if x["role"] == "user"
+            ]
             # NOTE: only single-turn agent generation is supported. Create a new session for each input row
-            session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
+            session_create_response = await self.agents_api.create_agent_session(
+                agent_id, f"session-{i}"
+            )
             session_id = session_create_response.session_id
 
             turn_request = dict(
@@ -127,7 +135,12 @@ class MetaReferenceEvalImpl(
                 messages=input_messages,
                 stream=True,
             )
-            turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
+            turn_response = [
+                chunk
+                async for chunk in await self.agents_api.create_agent_turn(
+                    **turn_request
+                )
+            ]
             final_event = turn_response[-1].event.payload
 
             # check if there's a memory retrieval step and extract the context
@@ -136,10 +149,14 @@ class MetaReferenceEvalImpl(
                 if step.step_type == StepType.tool_execution.value:
                     for tool_response in step.tool_responses:
                         if tool_response.tool_name == MEMORY_QUERY_TOOL:
-                            memory_rag_context = " ".join(x.text for x in tool_response.content)
+                            memory_rag_context = " ".join(
+                                x.text for x in tool_response.content
+                            )
 
             agent_generation = {}
-            agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
+            agent_generation[ColumnName.generated_answer.value] = (
+                final_event.turn.output_message.content
+            )
             if memory_rag_context:
                 agent_generation[ColumnName.context.value] = memory_rag_context
@@ -151,7 +168,9 @@ class MetaReferenceEvalImpl(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
     ) -> List[Dict[str, Any]]:
         candidate = benchmark_config.eval_candidate
-        assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
+        assert (
+            candidate.sampling_params.max_tokens is not None
+        ), "SamplingParams.max_tokens must be provided"
 
         generations = []
         for x in tqdm(input_rows):
@@ -162,21 +181,39 @@ class MetaReferenceEvalImpl(
                     content=input_content,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
             elif ColumnName.chat_completion_input.value in x:
-                chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
-                input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
+                chat_completion_input_json = json.loads(
+                    x[ColumnName.chat_completion_input.value]
+                )
+                input_messages = [
+                    UserMessage(**x)
+                    for x in chat_completion_input_json
+                    if x["role"] == "user"
+                ]
                 messages = []
                 if candidate.system_message:
                     messages.append(candidate.system_message)
-                messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
+                messages += [
+                    SystemMessage(**x)
+                    for x in chat_completion_input_json
+                    if x["role"] == "system"
+                ]
                 messages += input_messages
                 response = await self.inference_api.chat_completion(
                     model_id=candidate.model,
                     messages=messages,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
             else:
                 raise ValueError("Invalid input row")
@@ -199,7 +236,8 @@ class MetaReferenceEvalImpl(
         # scoring with generated_answer
         score_input_rows = [
-            input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
+            input_r | generated_r
+            for input_r, generated_r in zip(input_rows, generations, strict=False)
         ]
 
         if benchmark_config.scoring_params is not None:
@@ -208,7 +246,9 @@ class MetaReferenceEvalImpl(
                 for scoring_fn_id in scoring_functions
             }
         else:
-            scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}
+            scoring_functions_dict = {
+                scoring_fn_id: None for scoring_fn_id in scoring_functions
+            }
 
         score_response = await self.scoring_api.score(
             input_rows=score_input_rows, scoring_functions=scoring_functions_dict
@@ -216,17 +256,18 @@ class MetaReferenceEvalImpl(
         return EvaluateResponse(generations=generations, scores=score_response.results)
 
-    async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
+    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
         if job_id in self.jobs:
-            return JobStatus.completed
-        return None
+            return Job(job_id=job_id, status=JobStatus.completed)
+        raise ValueError(f"Job {job_id} not found")
 
     async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")
 
     async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
-        status = await self.job_status(benchmark_id, job_id)
+        job = await self.job_status(benchmark_id, job_id)
+        status = job.status
         if not status or status != JobStatus.completed:
             raise ValueError(f"Job is not completed, Status: {status.value}")


@@ -94,7 +94,7 @@ def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id):
     )
     assert response.job_id == "0"
     job_status = llama_stack_client.eval.jobs.status(job_id=response.job_id, benchmark_id=benchmark_id)
-    assert job_status and job_status == "completed"
+    assert job_status and job_status.status == "completed"
 
     eval_response = llama_stack_client.eval.jobs.retrieve(job_id=response.job_id, benchmark_id=benchmark_id)
     assert eval_response is not None