mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 02:32:40 +00:00)

commit 2f140c7ccf ("precommit"), parent d6887f46c6

7 changed files with 235 additions and 116 deletions
docs/_static/llama-stack-spec.html (vendored): 58 changed lines
@@ -2183,7 +2183,7 @@
           "content": {
             "application/json": {
               "schema": {
-                "$ref": "#/components/schemas/JobStatus"
+                "$ref": "#/components/schemas/Job"
               }
             }
           }
@@ -7648,16 +7648,6 @@
         "title": "PostTrainingJobArtifactsResponse",
         "description": "Artifacts of a finetuning job."
       },
-      "JobStatus": {
-        "type": "string",
-        "enum": [
-          "completed",
-          "in_progress",
-          "failed",
-          "scheduled"
-        ],
-        "title": "JobStatus"
-      },
       "PostTrainingJobStatusResponse": {
         "type": "object",
         "properties": {
@@ -7665,7 +7655,14 @@
             "type": "string"
           },
           "status": {
-            "$ref": "#/components/schemas/JobStatus"
+            "type": "string",
+            "enum": [
+              "completed",
+              "in_progress",
+              "failed",
+              "scheduled"
+            ],
+            "title": "JobStatus"
           },
           "scheduled_at": {
             "type": "string",
@@ -8115,6 +8112,30 @@
         "title": "IterrowsResponse",
         "description": "A paginated list of rows from a dataset."
       },
+      "Job": {
+        "type": "object",
+        "properties": {
+          "job_id": {
+            "type": "string"
+          },
+          "status": {
+            "type": "string",
+            "enum": [
+              "completed",
+              "in_progress",
+              "failed",
+              "scheduled"
+            ],
+            "title": "JobStatus"
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "job_id",
+          "status"
+        ],
+        "title": "Job"
+      },
       "ListAgentSessionsResponse": {
         "type": "object",
         "properties": {
@@ -9639,19 +9660,6 @@
         ],
         "title": "RunEvalRequest"
       },
-      "Job": {
-        "type": "object",
-        "properties": {
-          "job_id": {
-            "type": "string"
-          }
-        },
-        "additionalProperties": false,
-        "required": [
-          "job_id"
-        ],
-        "title": "Job"
-      },
       "RunShieldRequest": {
         "type": "object",
         "properties": {
docs/_static/llama-stack-spec.yaml (vendored): 45 changed lines
@@ -1491,7 +1491,7 @@ paths:
           content:
             application/json:
               schema:
-                $ref: '#/components/schemas/JobStatus'
+                $ref: '#/components/schemas/Job'
         '400':
           $ref: '#/components/responses/BadRequest400'
         '429':
@@ -5277,21 +5277,19 @@ components:
         - checkpoints
       title: PostTrainingJobArtifactsResponse
       description: Artifacts of a finetuning job.
-    JobStatus:
-      type: string
-      enum:
-        - completed
-        - in_progress
-        - failed
-        - scheduled
-      title: JobStatus
     PostTrainingJobStatusResponse:
       type: object
       properties:
         job_uuid:
           type: string
         status:
-          $ref: '#/components/schemas/JobStatus'
+          type: string
+          enum:
+            - completed
+            - in_progress
+            - failed
+            - scheduled
+          title: JobStatus
         scheduled_at:
           type: string
           format: date-time
@@ -5556,6 +5554,24 @@ components:
         - data
       title: IterrowsResponse
       description: A paginated list of rows from a dataset.
+    Job:
+      type: object
+      properties:
+        job_id:
+          type: string
+        status:
+          type: string
+          enum:
+            - completed
+            - in_progress
+            - failed
+            - scheduled
+          title: JobStatus
+      additionalProperties: false
+      required:
+        - job_id
+        - status
+      title: Job
     ListAgentSessionsResponse:
       type: object
       properties:
@@ -6550,15 +6566,6 @@ components:
       required:
         - benchmark_config
       title: RunEvalRequest
-    Job:
-      type: object
-      properties:
-        job_id:
-          type: string
-      additionalProperties: false
-      required:
-        - job_id
-      title: Job
    RunShieldRequest:
      type: object
      properties:
@@ -10,14 +10,14 @@ from pydantic import BaseModel
 from llama_stack.schema_utils import json_schema_type
 
 
-@json_schema_type
-class Job(BaseModel):
-    job_id: str
-
-
 @json_schema_type
 class JobStatus(Enum):
     completed = "completed"
     in_progress = "in_progress"
     failed = "failed"
     scheduled = "scheduled"
+
+
+@json_schema_type
+class Job(BaseModel):
+    job_id: str
+    status: JobStatus
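To make the shape of the reworked model concrete, here is a minimal usage sketch (not part of the commit); it mirrors the classes in the hunk above and assumes pydantic v2, as used by the repo.

from enum import Enum

from pydantic import BaseModel


class JobStatus(Enum):
    completed = "completed"
    in_progress = "in_progress"
    failed = "failed"
    scheduled = "scheduled"


class Job(BaseModel):
    job_id: str
    status: JobStatus


job = Job(job_id="0", status=JobStatus.completed)
# Serializes to {"job_id": "0", "status": "completed"}, the same shape the
# regenerated "Job" schema in llama-stack-spec.{html,yaml} above now describes.
print(job.model_dump(mode="json"))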
@@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
 from llama_stack.apis.agents import AgentConfig
-from llama_stack.apis.common.job_types import Job, JobStatus
+from llama_stack.apis.common.job_types import Job
 from llama_stack.apis.inference import SamplingParams, SystemMessage
 from llama_stack.apis.scoring import ScoringResult
 from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -115,7 +115,7 @@ class Eval(Protocol):
         """
 
     @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
-    async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
+    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
         """Get the status of a job.
 
         :param benchmark_id: The ID of the benchmark to run the evaluation on.
@@ -8,19 +8,13 @@ import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from llama_stack.apis.common.content_types import (
-    URL,
     InterleavedContent,
     InterleavedContentItem,
+    URL,
 )
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
-from llama_stack.apis.eval import (
-    BenchmarkConfig,
-    Eval,
-    EvaluateResponse,
-    Job,
-    JobStatus,
-)
+from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEventType,
@@ -94,7 +88,9 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
+        logger.debug(
+            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
+        )
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -112,7 +108,9 @@ class VectorIORouter(VectorIO):
         logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
-        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
+        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
+            vector_db_id, chunks, ttl_seconds
+        )
 
     async def query_chunks(
         self,
@@ -121,7 +119,9 @@ class VectorIORouter(VectorIO):
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
+        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
+            vector_db_id, query, params
+        )
 
 
 class InferenceRouter(Inference):
@@ -158,7 +158,9 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
+        await self.routing_table.register_model(
+            model_id, provider_model_id, provider_id, metadata, model_type
+        )
 
     def _construct_metrics(
         self,
@@ -212,11 +214,16 @@
         total_tokens: int,
         model: Model,
     ) -> List[MetricInResponse]:
-        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        metrics = self._construct_metrics(
+            prompt_tokens, completion_tokens, total_tokens, model
+        )
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
-        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
+        return [
+            MetricInResponse(metric=metric.metric, value=metric.value)
+            for metric in metrics
+        ]
 
     async def _count_tokens(
         self,
@@ -241,7 +248,9 @@
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]:
         logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
@@ -251,12 +260,19 @@
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+            raise ValueError(
+                f"Model '{model_id}' is an embedding model and does not support chat completions"
+            )
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
-                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
+            if (
+                tool_prompt_format
+                and tool_prompt_format != tool_config.tool_prompt_format
+            ):
+                raise ValueError(
+                    "tool_prompt_format and tool_config.tool_prompt_format must match"
+                )
         else:
             params = {}
             if tool_choice:
@@ -274,9 +290,14 @@
                 pass
             else:
                 # verify tool_choice is one of the tools
-                tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
+                tool_names = [
+                    t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
+                    for t in tools
+                ]
                 if tool_config.tool_choice not in tool_names:
-                    raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
+                    raise ValueError(
+                        f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
+                    )
 
         params = dict(
             model_id=model_id,
@@ -291,17 +312,25 @@
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+        prompt_tokens = await self._count_tokens(
+            messages, tool_config.tool_prompt_format
+        )
 
         if stream:
 
             async def stream_generator():
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
-                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+                    if (
+                        chunk.event.event_type
+                        == ChatCompletionResponseEventType.progress
+                    ):
                         if chunk.event.delta.type == "text":
                             completion_text += chunk.event.delta.text
-                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+                    if (
+                        chunk.event.event_type
+                        == ChatCompletionResponseEventType.complete
+                    ):
                         completion_tokens = await self._count_tokens(
                             [
                                 CompletionMessage(
@@ -318,7 +347,11 @@
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
+                        chunk.metrics = (
+                            metrics
+                            if chunk.metrics is None
+                            else chunk.metrics + metrics
+                        )
                     yield chunk
 
             return stream_generator()
@@ -335,7 +368,9 @@
                 total_tokens,
                 model,
             )
-            response.metrics = metrics if response.metrics is None else response.metrics + metrics
+            response.metrics = (
+                metrics if response.metrics is None else response.metrics + metrics
+            )
             return response
 
     async def completion(
@@ -356,7 +391,9 @@
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+            raise ValueError(
+                f"Model '{model_id}' is an embedding model and does not support chat completions"
+            )
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -376,7 +413,11 @@
                 async for chunk in await provider.completion(**params):
                     if hasattr(chunk, "delta"):
                         completion_text += chunk.delta
-                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                    if (
+                        hasattr(chunk, "stop_reason")
+                        and chunk.stop_reason
+                        and self.telemetry
+                    ):
                         completion_tokens = await self._count_tokens(completion_text)
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                         metrics = await self._compute_and_log_token_usage(
@@ -385,7 +426,11 @@
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
+                        chunk.metrics = (
+                            metrics
+                            if chunk.metrics is None
+                            else chunk.metrics + metrics
+                        )
                     yield chunk
 
             return stream_generator()
@@ -399,7 +444,9 @@
                 total_tokens,
                 model,
             )
-            response.metrics = metrics if response.metrics is None else response.metrics + metrics
+            response.metrics = (
+                metrics if response.metrics is None else response.metrics + metrics
+            )
             return response
 
     async def embeddings(
@@ -415,7 +462,9 @@
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.llm:
-            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
+            raise ValueError(
+                f"Model '{model_id}' is an LLM model and does not support embeddings"
+            )
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,
@@ -449,7 +498,9 @@ class SafetyRouter(Safety):
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
         logger.debug(f"SafetyRouter.register_shield: {shield_id}")
-        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
+        return await self.routing_table.register_shield(
+            shield_id, provider_shield_id, provider_id, params
+        )
 
     async def run_shield(
         self,
@@ -546,7 +597,9 @@ class ScoringRouter(Scoring):
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
+            score_response = await self.routing_table.get_provider_impl(
+                fn_identifier
+            ).score_batch(
                 dataset_id=dataset_id,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -564,11 +617,15 @@
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
+        logger.debug(
+            f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions"
+        )
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
+            score_response = await self.routing_table.get_provider_impl(
+                fn_identifier
+            ).score(
                 input_rows=input_rows,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -611,7 +668,9 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
-        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
+        logger.debug(
+            f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows"
+        )
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -623,9 +682,11 @@
         self,
         benchmark_id: str,
         job_id: str,
-    ) -> Optional[JobStatus]:
+    ) -> Job:
         logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
+        return await self.routing_table.get_provider_impl(benchmark_id).job_status(
+            benchmark_id, job_id
+        )
 
     async def job_cancel(
         self,
@@ -679,9 +740,9 @@ class ToolRuntimeRouter(ToolRuntime):
             logger.debug(
                 f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
             )
-            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
-                documents, vector_db_id, chunk_size_in_tokens
-            )
+            return await self.routing_table.get_provider_impl(
+                "insert_into_memory"
+            ).insert(documents, vector_db_id, chunk_size_in_tokens)
 
     def __init__(
         self,
@@ -714,4 +775,6 @@ class ToolRuntimeRouter(ToolRuntime):
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
-        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
+            tool_group_id, mcp_endpoint
+        )
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
 from tqdm import tqdm
 
@@ -21,8 +21,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
 from llama_stack.providers.utils.common.data_schema_validator import ColumnName
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
-from .....apis.common.job_types import Job
-from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus
+from .....apis.common.job_types import Job, JobStatus
+from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
 from .config import MetaReferenceEvalConfig
 
 EVAL_TASKS_PREFIX = "benchmarks:"
@@ -89,7 +89,11 @@ class MetaReferenceEvalImpl(
 
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
-            limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
+            limit=(
+                -1
+                if benchmark_config.num_examples is None
+                else benchmark_config.num_examples
+            ),
         )
         res = await self.evaluate_rows(
             benchmark_id=benchmark_id,
@@ -102,7 +106,7 @@
         # need job scheduler queue (ray/celery) w/ jobs api
         job_id = str(len(self.jobs))
         self.jobs[job_id] = res
-        return Job(job_id=job_id)
+        return Job(job_id=job_id, status=JobStatus.completed)
 
     async def _run_agent_generation(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@@ -115,10 +119,14 @@
         for i, x in tqdm(enumerate(input_rows)):
             assert ColumnName.chat_completion_input.value in x, "Invalid input row"
             input_messages = json.loads(x[ColumnName.chat_completion_input.value])
-            input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
+            input_messages = [
+                UserMessage(**x) for x in input_messages if x["role"] == "user"
+            ]
 
             # NOTE: only single-turn agent generation is supported. Create a new session for each input row
-            session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
+            session_create_response = await self.agents_api.create_agent_session(
+                agent_id, f"session-{i}"
+            )
             session_id = session_create_response.session_id
 
             turn_request = dict(
@@ -127,7 +135,12 @@
                 messages=input_messages,
                 stream=True,
             )
-            turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
+            turn_response = [
+                chunk
+                async for chunk in await self.agents_api.create_agent_turn(
+                    **turn_request
+                )
+            ]
             final_event = turn_response[-1].event.payload
 
             # check if there's a memory retrieval step and extract the context
@@ -136,10 +149,14 @@
                 if step.step_type == StepType.tool_execution.value:
                     for tool_response in step.tool_responses:
                         if tool_response.tool_name == MEMORY_QUERY_TOOL:
-                            memory_rag_context = " ".join(x.text for x in tool_response.content)
+                            memory_rag_context = " ".join(
+                                x.text for x in tool_response.content
+                            )
 
             agent_generation = {}
-            agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
+            agent_generation[ColumnName.generated_answer.value] = (
+                final_event.turn.output_message.content
+            )
             if memory_rag_context:
                 agent_generation[ColumnName.context.value] = memory_rag_context
 
@@ -151,7 +168,9 @@
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
     ) -> List[Dict[str, Any]]:
         candidate = benchmark_config.eval_candidate
-        assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
+        assert (
+            candidate.sampling_params.max_tokens is not None
+        ), "SamplingParams.max_tokens must be provided"
 
         generations = []
         for x in tqdm(input_rows):
@@ -162,21 +181,39 @@
                     content=input_content,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
             elif ColumnName.chat_completion_input.value in x:
-                chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
-                input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
+                chat_completion_input_json = json.loads(
+                    x[ColumnName.chat_completion_input.value]
+                )
+                input_messages = [
+                    UserMessage(**x)
+                    for x in chat_completion_input_json
+                    if x["role"] == "user"
+                ]
                 messages = []
                 if candidate.system_message:
                     messages.append(candidate.system_message)
-                messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
+                messages += [
+                    SystemMessage(**x)
+                    for x in chat_completion_input_json
+                    if x["role"] == "system"
+                ]
                 messages += input_messages
                 response = await self.inference_api.chat_completion(
                     model_id=candidate.model,
                     messages=messages,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
+                generations.append(
+                    {
+                        ColumnName.generated_answer.value: response.completion_message.content
+                    }
+                )
             else:
                 raise ValueError("Invalid input row")
 
@@ -199,7 +236,8 @@
 
         # scoring with generated_answer
         score_input_rows = [
-            input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
+            input_r | generated_r
+            for input_r, generated_r in zip(input_rows, generations, strict=False)
         ]
 
         if benchmark_config.scoring_params is not None:
@@ -208,7 +246,9 @@
                 for scoring_fn_id in scoring_functions
             }
         else:
-            scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}
+            scoring_functions_dict = {
+                scoring_fn_id: None for scoring_fn_id in scoring_functions
+            }
 
         score_response = await self.scoring_api.score(
             input_rows=score_input_rows, scoring_functions=scoring_functions_dict
@@ -216,17 +256,18 @@
 
         return EvaluateResponse(generations=generations, scores=score_response.results)
 
-    async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
+    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
         if job_id in self.jobs:
-            return JobStatus.completed
+            return Job(job_id=job_id, status=JobStatus.completed)
 
-        return None
+        raise ValueError(f"Job {job_id} not found")
 
     async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")
 
     async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
-        status = await self.job_status(benchmark_id, job_id)
+        job = await self.job_status(benchmark_id, job_id)
+        status = job.status
         if not status or status != JobStatus.completed:
             raise ValueError(f"Job is not completed, Status: {status.value}")
 
@@ -94,7 +94,7 @@ def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id):
     )
     assert response.job_id == "0"
     job_status = llama_stack_client.eval.jobs.status(job_id=response.job_id, benchmark_id=benchmark_id)
-    assert job_status and job_status == "completed"
+    assert job_status and job_status.status == "completed"
 
     eval_response = llama_stack_client.eval.jobs.retrieve(job_id=response.job_id, benchmark_id=benchmark_id)
     assert eval_response is not None
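For orientation, a usage sketch in the spirit of the updated test (not part of this change): it mirrors the client calls shown above, and the base URL, benchmark id, and job id are placeholder values for an assumed locally running stack with a registered benchmark.

from llama_stack_client import LlamaStackClient

# Sketch only: "http://localhost:8321" and "my-benchmark" are placeholders, not values
# from this commit; they assume a local stack server with an already-registered benchmark.
client = LlamaStackClient(base_url="http://localhost:8321")

job = client.eval.jobs.status(job_id="0", benchmark_id="my-benchmark")
# With this change the status endpoint returns a Job object, so the status is read
# from a field on the response instead of comparing the response itself to a string.
assert job.status == "completed"

eval_response = client.eval.jobs.retrieve(job_id="0", benchmark_id="my-benchmark")
assert eval_response is not None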