precommit

Xi Yan 2025-03-21 13:43:47 -07:00
parent 2f140c7ccf
commit ce6fc9f851
2 changed files with 47 additions and 156 deletions
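Note: every hunk below is a mechanical formatting change with no behavioral edits; multi-line calls, conditionals, and comprehensions are collapsed onto single lines, and one import block is re-ordered. This is consistent with re-running the repository's pre-commit formatting hooks under a wider maximum line length, though the hook configuration itself is not part of this diff. The snippet below is a hedged, self-contained illustration of the pattern (the logger name and variable are placeholders, not code from the commit):

import logging

logger = logging.getLogger("example")
vector_db_id = "example-db"

# Before: the call is wrapped to satisfy a shorter line limit.
logger.debug(
    f"VectorIORouter.query_chunks: {vector_db_id}"
)

# After: under a wider limit the same call fits on one line, so the wrapping
# parentheses and continuation lines are dropped. The logged output is identical.
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")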

Changed file 1 of 2

@@ -8,9 +8,9 @@ import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 from llama_stack.apis.common.content_types import (
+    URL,
     InterleavedContent,
     InterleavedContentItem,
-    URL,
 )
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
@ -88,9 +88,7 @@ class VectorIORouter(VectorIO):
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None, provider_vector_db_id: Optional[str] = None,
) -> None: ) -> None:
logger.debug( logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
)
await self.routing_table.register_vector_db( await self.routing_table.register_vector_db(
vector_db_id, vector_db_id,
embedding_model, embedding_model,
@@ -108,9 +106,7 @@ class VectorIORouter(VectorIO):
         logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
-        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
-            vector_db_id, chunks, ttl_seconds
-        )
+        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
     async def query_chunks(
         self,
@@ -119,9 +115,7 @@ class VectorIORouter(VectorIO):
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
-            vector_db_id, query, params
-        )
+        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
 class InferenceRouter(Inference):
@@ -158,9 +152,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, metadata, model_type
-        )
+        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
     def _construct_metrics(
         self,
@@ -214,16 +206,11 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> List[MetricInResponse]:
-        metrics = self._construct_metrics(
-            prompt_tokens, completion_tokens, total_tokens, model
-        )
+        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
-        return [
-            MetricInResponse(metric=metric.metric, value=metric.value)
-            for metric in metrics
-        ]
+        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
     async def _count_tokens(
         self,
@@ -248,9 +235,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
@@ -260,19 +245,12 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
+            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if (
-                tool_prompt_format
-                and tool_prompt_format != tool_config.tool_prompt_format
-            ):
-                raise ValueError(
-                    "tool_prompt_format and tool_config.tool_prompt_format must match"
-                )
+            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
+                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
         else:
             params = {}
             if tool_choice:
@@ -290,14 +268,9 @@ class InferenceRouter(Inference):
                 pass
             else:
                 # verify tool_choice is one of the tools
-                tool_names = [
-                    t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
-                    for t in tools
-                ]
+                tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
                 if tool_config.tool_choice not in tool_names:
-                    raise ValueError(
-                        f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
-                    )
+                    raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
         params = dict(
             model_id=model_id,
@@ -312,25 +285,17 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(
-            messages, tool_config.tool_prompt_format
-        )
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
         if stream:
             async def stream_generator():
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
-                    if (
-                        chunk.event.event_type
-                        == ChatCompletionResponseEventType.progress
-                    ):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                         if chunk.event.delta.type == "text":
                             completion_text += chunk.event.delta.text
-                    if (
-                        chunk.event.event_type
-                        == ChatCompletionResponseEventType.complete
-                    ):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
                         completion_tokens = await self._count_tokens(
                             [
                                 CompletionMessage(
@@ -347,11 +312,7 @@ class InferenceRouter(Inference):
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = (
-                            metrics
-                            if chunk.metrics is None
-                            else chunk.metrics + metrics
-                        )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk
             return stream_generator()
@@ -368,9 +329,7 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = (
-            metrics if response.metrics is None else response.metrics + metrics
-        )
+        response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response
     async def completion(
@@ -391,9 +350,7 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
+            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -413,11 +370,7 @@ class InferenceRouter(Inference):
                 async for chunk in await provider.completion(**params):
                     if hasattr(chunk, "delta"):
                         completion_text += chunk.delta
-                    if (
-                        hasattr(chunk, "stop_reason")
-                        and chunk.stop_reason
-                        and self.telemetry
-                    ):
+                    if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                         completion_tokens = await self._count_tokens(completion_text)
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                         metrics = await self._compute_and_log_token_usage(
@@ -426,11 +379,7 @@ class InferenceRouter(Inference):
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = (
-                            metrics
-                            if chunk.metrics is None
-                            else chunk.metrics + metrics
-                        )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk
             return stream_generator()
@@ -444,9 +393,7 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = (
-            metrics if response.metrics is None else response.metrics + metrics
-        )
+        response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response
     async def embeddings(
@@ -462,9 +409,7 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.llm:
-            raise ValueError(
-                f"Model '{model_id}' is an LLM model and does not support embeddings"
-            )
+            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,
@@ -498,9 +443,7 @@ class SafetyRouter(Safety):
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
         logger.debug(f"SafetyRouter.register_shield: {shield_id}")
-        return await self.routing_table.register_shield(
-            shield_id, provider_shield_id, provider_id, params
-        )
+        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
     async def run_shield(
         self,
@@ -597,9 +540,7 @@ class ScoringRouter(Scoring):
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(
-                fn_identifier
-            ).score_batch(
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
                 dataset_id=dataset_id,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -617,15 +558,11 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logger.debug(
-            f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions"
-        )
+        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(
-                fn_identifier
-            ).score(
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
                 input_rows=input_rows,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -668,9 +605,7 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
-        logger.debug(
-            f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows"
-        )
+        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -684,9 +619,7 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> Job:
         logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(
-            benchmark_id, job_id
-        )
+        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
     async def job_cancel(
         self,
@@ -740,9 +673,9 @@ class ToolRuntimeRouter(ToolRuntime):
             logger.debug(
                 f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
             )
-            return await self.routing_table.get_provider_impl(
-                "insert_into_memory"
-            ).insert(documents, vector_db_id, chunk_size_in_tokens)
+            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
+                documents, vector_db_id, chunk_size_in_tokens
+            )
     def __init__(
         self,
@@ -775,6 +708,4 @@ class ToolRuntimeRouter(ToolRuntime):
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
-        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
-            tool_group_id, mcp_endpoint
-        )
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

Changed file 2 of 2

@@ -89,11 +89,7 @@ class MetaReferenceEvalImpl(
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
-            limit=(
-                -1
-                if benchmark_config.num_examples is None
-                else benchmark_config.num_examples
-            ),
+            limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
         )
         res = await self.evaluate_rows(
             benchmark_id=benchmark_id,
@@ -119,14 +115,10 @@ class MetaReferenceEvalImpl(
         for i, x in tqdm(enumerate(input_rows)):
             assert ColumnName.chat_completion_input.value in x, "Invalid input row"
             input_messages = json.loads(x[ColumnName.chat_completion_input.value])
-            input_messages = [
-                UserMessage(**x) for x in input_messages if x["role"] == "user"
-            ]
+            input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
             # NOTE: only single-turn agent generation is supported. Create a new session for each input row
-            session_create_response = await self.agents_api.create_agent_session(
-                agent_id, f"session-{i}"
-            )
+            session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
             session_id = session_create_response.session_id
             turn_request = dict(
@@ -135,12 +127,7 @@ class MetaReferenceEvalImpl(
                 messages=input_messages,
                 stream=True,
             )
-            turn_response = [
-                chunk
-                async for chunk in await self.agents_api.create_agent_turn(
-                    **turn_request
-                )
-            ]
+            turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
             final_event = turn_response[-1].event.payload
             # check if there's a memory retrieval step and extract the context
@@ -149,14 +136,10 @@ class MetaReferenceEvalImpl(
                 if step.step_type == StepType.tool_execution.value:
                     for tool_response in step.tool_responses:
                         if tool_response.tool_name == MEMORY_QUERY_TOOL:
-                            memory_rag_context = " ".join(
-                                x.text for x in tool_response.content
-                            )
+                            memory_rag_context = " ".join(x.text for x in tool_response.content)
             agent_generation = {}
-            agent_generation[ColumnName.generated_answer.value] = (
-                final_event.turn.output_message.content
-            )
+            agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
             if memory_rag_context:
                 agent_generation[ColumnName.context.value] = memory_rag_context
@@ -168,9 +151,7 @@ class MetaReferenceEvalImpl(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
     ) -> List[Dict[str, Any]]:
         candidate = benchmark_config.eval_candidate
-        assert (
-            candidate.sampling_params.max_tokens is not None
-        ), "SamplingParams.max_tokens must be provided"
+        assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
         generations = []
         for x in tqdm(input_rows):
@@ -181,39 +162,21 @@ class MetaReferenceEvalImpl(
                     content=input_content,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append(
-                    {
-                        ColumnName.generated_answer.value: response.completion_message.content
-                    }
-                )
+                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
             elif ColumnName.chat_completion_input.value in x:
-                chat_completion_input_json = json.loads(
-                    x[ColumnName.chat_completion_input.value]
-                )
-                input_messages = [
-                    UserMessage(**x)
-                    for x in chat_completion_input_json
-                    if x["role"] == "user"
-                ]
+                chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
+                input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
                 messages = []
                 if candidate.system_message:
                     messages.append(candidate.system_message)
-                messages += [
-                    SystemMessage(**x)
-                    for x in chat_completion_input_json
-                    if x["role"] == "system"
-                ]
+                messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
                 messages += input_messages
                 response = await self.inference_api.chat_completion(
                     model_id=candidate.model,
                     messages=messages,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append(
-                    {
-                        ColumnName.generated_answer.value: response.completion_message.content
-                    }
-                )
+                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
             else:
                 raise ValueError("Invalid input row")
@@ -236,8 +199,7 @@ class MetaReferenceEvalImpl(
         # scoring with generated_answer
         score_input_rows = [
-            input_r | generated_r
-            for input_r, generated_r in zip(input_rows, generations, strict=False)
+            input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
         ]
         if benchmark_config.scoring_params is not None:
@@ -246,9 +208,7 @@ class MetaReferenceEvalImpl(
                 for scoring_fn_id in scoring_functions
             }
         else:
-            scoring_functions_dict = {
-                scoring_fn_id: None for scoring_fn_id in scoring_functions
-            }
+            scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}
         score_response = await self.scoring_api.score(
             input_rows=score_input_rows, scoring_functions=scoring_functions_dict