diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 1a98888a2..6ff36a65c 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -8,9 +8,9 @@ import time from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from llama_stack.apis.common.content_types import ( + URL, InterleavedContent, InterleavedContentItem, - URL, ) from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.datasets import DatasetPurpose, DataSource @@ -88,9 +88,7 @@ class VectorIORouter(VectorIO): provider_id: Optional[str] = None, provider_vector_db_id: Optional[str] = None, ) -> None: - logger.debug( - f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}" - ) + logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") await self.routing_table.register_vector_db( vector_db_id, embedding_model, @@ -108,9 +106,7 @@ class VectorIORouter(VectorIO): logger.debug( f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", ) - return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks( - vector_db_id, chunks, ttl_seconds - ) + return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) async def query_chunks( self, @@ -119,9 +115,7 @@ class VectorIORouter(VectorIO): params: Optional[Dict[str, Any]] = None, ) -> QueryChunksResponse: logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") - return await self.routing_table.get_provider_impl(vector_db_id).query_chunks( - vector_db_id, query, params - ) + return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) class InferenceRouter(Inference): @@ -158,9 +152,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", ) - await self.routing_table.register_model( - model_id, provider_model_id, provider_id, metadata, model_type - ) + await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) def _construct_metrics( self, @@ -214,16 +206,11 @@ class InferenceRouter(Inference): total_tokens: int, model: Model, ) -> List[MetricInResponse]: - metrics = self._construct_metrics( - prompt_tokens, completion_tokens, total_tokens, model - ) + metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) if self.telemetry: for metric in metrics: await self.telemetry.log_event(metric) - return [ - MetricInResponse(metric=metric.metric, value=metric.value) - for metric in metrics - ] + return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] async def _count_tokens( self, @@ -248,9 +235,7 @@ class InferenceRouter(Inference): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, - ) -> Union[ - ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] - ]: + ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: logger.debug( f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", ) @@ -260,19 +245,12 @@ class InferenceRouter(Inference): if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: - raise ValueError( - f"Model '{model_id}' is an embedding model and does not support chat completions" - ) + raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") if tool_config: if tool_choice and tool_choice != tool_config.tool_choice: raise ValueError("tool_choice and tool_config.tool_choice must match") - if ( - tool_prompt_format - and tool_prompt_format != tool_config.tool_prompt_format - ): - raise ValueError( - "tool_prompt_format and tool_config.tool_prompt_format must match" - ) + if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format: + raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match") else: params = {} if tool_choice: @@ -290,14 +268,9 @@ class InferenceRouter(Inference): pass else: # verify tool_choice is one of the tools - tool_names = [ - t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value - for t in tools - ] + tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] if tool_config.tool_choice not in tool_names: - raise ValueError( - f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}" - ) + raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") params = dict( model_id=model_id, @@ -312,25 +285,17 @@ class InferenceRouter(Inference): tool_config=tool_config, ) provider = self.routing_table.get_provider_impl(model_id) - prompt_tokens = await self._count_tokens( - messages, tool_config.tool_prompt_format - ) + prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) if stream: async def stream_generator(): completion_text = "" async for chunk in await provider.chat_completion(**params): - if ( - chunk.event.event_type - == ChatCompletionResponseEventType.progress - ): + if chunk.event.event_type == ChatCompletionResponseEventType.progress: if chunk.event.delta.type == "text": completion_text += chunk.event.delta.text - if ( - chunk.event.event_type - == ChatCompletionResponseEventType.complete - ): + if chunk.event.event_type == ChatCompletionResponseEventType.complete: completion_tokens = await self._count_tokens( [ CompletionMessage( @@ -347,11 +312,7 @@ class InferenceRouter(Inference): total_tokens, model, ) - chunk.metrics = ( - metrics - if chunk.metrics is None - else chunk.metrics + metrics - ) + chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics yield chunk return stream_generator() @@ -368,9 +329,7 @@ class InferenceRouter(Inference): total_tokens, model, ) - response.metrics = ( - metrics if response.metrics is None else response.metrics + metrics - ) + response.metrics = metrics if response.metrics is None else response.metrics + metrics return response async def completion( @@ -391,9 +350,7 @@ class InferenceRouter(Inference): if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: - raise ValueError( - f"Model '{model_id}' is an embedding model and does not support chat completions" - ) + raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") provider = self.routing_table.get_provider_impl(model_id) params = dict( model_id=model_id, @@ -413,11 +370,7 @@ class InferenceRouter(Inference): async for chunk in await provider.completion(**params): if hasattr(chunk, "delta"): completion_text += chunk.delta - if ( - hasattr(chunk, "stop_reason") - and chunk.stop_reason - and self.telemetry - ): + if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: completion_tokens = await self._count_tokens(completion_text) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) metrics = await self._compute_and_log_token_usage( @@ -426,11 +379,7 @@ class InferenceRouter(Inference): total_tokens, model, ) - chunk.metrics = ( - metrics - if chunk.metrics is None - else chunk.metrics + metrics - ) + chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics yield chunk return stream_generator() @@ -444,9 +393,7 @@ class InferenceRouter(Inference): total_tokens, model, ) - response.metrics = ( - metrics if response.metrics is None else response.metrics + metrics - ) + response.metrics = metrics if response.metrics is None else response.metrics + metrics return response async def embeddings( @@ -462,9 +409,7 @@ class InferenceRouter(Inference): if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.llm: - raise ValueError( - f"Model '{model_id}' is an LLM model and does not support embeddings" - ) + raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings") return await self.routing_table.get_provider_impl(model_id).embeddings( model_id=model_id, contents=contents, @@ -498,9 +443,7 @@ class SafetyRouter(Safety): params: Optional[Dict[str, Any]] = None, ) -> Shield: logger.debug(f"SafetyRouter.register_shield: {shield_id}") - return await self.routing_table.register_shield( - shield_id, provider_shield_id, provider_id, params - ) + return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) async def run_shield( self, @@ -597,9 +540,7 @@ class ScoringRouter(Scoring): logger.debug(f"ScoringRouter.score_batch: {dataset_id}") res = {} for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl( - fn_identifier - ).score_batch( + score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( dataset_id=dataset_id, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) @@ -617,15 +558,11 @@ class ScoringRouter(Scoring): input_rows: List[Dict[str, Any]], scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, ) -> ScoreResponse: - logger.debug( - f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions" - ) + logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") res = {} # look up and map each scoring function to its provider impl for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl( - fn_identifier - ).score( + score_response = await self.routing_table.get_provider_impl(fn_identifier).score( input_rows=input_rows, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) @@ -668,9 +605,7 @@ class EvalRouter(Eval): scoring_functions: List[str], benchmark_config: BenchmarkConfig, ) -> EvaluateResponse: - logger.debug( - f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows" - ) + logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( benchmark_id=benchmark_id, input_rows=input_rows, @@ -684,9 +619,7 @@ class EvalRouter(Eval): job_id: str, ) -> Job: logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_status( - benchmark_id, job_id - ) + return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) async def job_cancel( self, @@ -740,9 +673,9 @@ class ToolRuntimeRouter(ToolRuntime): logger.debug( f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" ) - return await self.routing_table.get_provider_impl( - "insert_into_memory" - ).insert(documents, vector_db_id, chunk_size_in_tokens) + return await self.routing_table.get_provider_impl("insert_into_memory").insert( + documents, vector_db_id, chunk_size_in_tokens + ) def __init__( self, @@ -775,6 +708,4 @@ class ToolRuntimeRouter(ToolRuntime): self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") - return await self.routing_table.get_provider_impl(tool_group_id).list_tools( - tool_group_id, mcp_endpoint - ) + return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index 84ed8437e..7c28f1bb7 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -89,11 +89,7 @@ class MetaReferenceEvalImpl( all_rows = await self.datasetio_api.iterrows( dataset_id=dataset_id, - limit=( - -1 - if benchmark_config.num_examples is None - else benchmark_config.num_examples - ), + limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples), ) res = await self.evaluate_rows( benchmark_id=benchmark_id, @@ -119,14 +115,10 @@ class MetaReferenceEvalImpl( for i, x in tqdm(enumerate(input_rows)): assert ColumnName.chat_completion_input.value in x, "Invalid input row" input_messages = json.loads(x[ColumnName.chat_completion_input.value]) - input_messages = [ - UserMessage(**x) for x in input_messages if x["role"] == "user" - ] + input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"] # NOTE: only single-turn agent generation is supported. Create a new session for each input row - session_create_response = await self.agents_api.create_agent_session( - agent_id, f"session-{i}" - ) + session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}") session_id = session_create_response.session_id turn_request = dict( @@ -135,12 +127,7 @@ class MetaReferenceEvalImpl( messages=input_messages, stream=True, ) - turn_response = [ - chunk - async for chunk in await self.agents_api.create_agent_turn( - **turn_request - ) - ] + turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)] final_event = turn_response[-1].event.payload # check if there's a memory retrieval step and extract the context @@ -149,14 +136,10 @@ class MetaReferenceEvalImpl( if step.step_type == StepType.tool_execution.value: for tool_response in step.tool_responses: if tool_response.tool_name == MEMORY_QUERY_TOOL: - memory_rag_context = " ".join( - x.text for x in tool_response.content - ) + memory_rag_context = " ".join(x.text for x in tool_response.content) agent_generation = {} - agent_generation[ColumnName.generated_answer.value] = ( - final_event.turn.output_message.content - ) + agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content if memory_rag_context: agent_generation[ColumnName.context.value] = memory_rag_context @@ -168,9 +151,7 @@ class MetaReferenceEvalImpl( self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig ) -> List[Dict[str, Any]]: candidate = benchmark_config.eval_candidate - assert ( - candidate.sampling_params.max_tokens is not None - ), "SamplingParams.max_tokens must be provided" + assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided" generations = [] for x in tqdm(input_rows): @@ -181,39 +162,21 @@ class MetaReferenceEvalImpl( content=input_content, sampling_params=candidate.sampling_params, ) - generations.append( - { - ColumnName.generated_answer.value: response.completion_message.content - } - ) + generations.append({ColumnName.generated_answer.value: response.completion_message.content}) elif ColumnName.chat_completion_input.value in x: - chat_completion_input_json = json.loads( - x[ColumnName.chat_completion_input.value] - ) - input_messages = [ - UserMessage(**x) - for x in chat_completion_input_json - if x["role"] == "user" - ] + chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value]) + input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"] messages = [] if candidate.system_message: messages.append(candidate.system_message) - messages += [ - SystemMessage(**x) - for x in chat_completion_input_json - if x["role"] == "system" - ] + messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"] messages += input_messages response = await self.inference_api.chat_completion( model_id=candidate.model, messages=messages, sampling_params=candidate.sampling_params, ) - generations.append( - { - ColumnName.generated_answer.value: response.completion_message.content - } - ) + generations.append({ColumnName.generated_answer.value: response.completion_message.content}) else: raise ValueError("Invalid input row") @@ -236,8 +199,7 @@ class MetaReferenceEvalImpl( # scoring with generated_answer score_input_rows = [ - input_r | generated_r - for input_r, generated_r in zip(input_rows, generations, strict=False) + input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False) ] if benchmark_config.scoring_params is not None: @@ -246,9 +208,7 @@ class MetaReferenceEvalImpl( for scoring_fn_id in scoring_functions } else: - scoring_functions_dict = { - scoring_fn_id: None for scoring_fn_id in scoring_functions - } + scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions} score_response = await self.scoring_api.score( input_rows=score_input_rows, scoring_functions=scoring_functions_dict