diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 22a1e46f9..875c8c94e 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -8,11 +8,11 @@ import time from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from llama_stack.apis.common.content_types import ( - URL, InterleavedContent, InterleavedContentItem, + URL, ) -from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.eval import ( BenchmarkConfig, Eval, @@ -93,7 +93,9 @@ class VectorIORouter(VectorIO): provider_id: Optional[str] = None, provider_vector_db_id: Optional[str] = None, ) -> None: - logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") + logger.debug( + f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}" + ) await self.routing_table.register_vector_db( vector_db_id, embedding_model, @@ -111,7 +113,9 @@ class VectorIORouter(VectorIO): logger.debug( f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", ) - return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) + return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks( + vector_db_id, chunks, ttl_seconds + ) async def query_chunks( self, @@ -120,7 +124,9 @@ class VectorIORouter(VectorIO): params: Optional[Dict[str, Any]] = None, ) -> QueryChunksResponse: logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") - return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) + return await self.routing_table.get_provider_impl(vector_db_id).query_chunks( + vector_db_id, query, params + ) class InferenceRouter(Inference): @@ -157,10 +163,16 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", ) - await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) + await self.routing_table.register_model( + model_id, provider_model_id, provider_id, metadata, model_type + ) def _construct_metrics( - self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model + self, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + model: Model, ) -> List[MetricEvent]: """Constructs a list of MetricEvent objects containing token usage metrics. @@ -207,11 +219,16 @@ class InferenceRouter(Inference): total_tokens: int, model: Model, ) -> List[MetricInResponse]: - metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) + metrics = self._construct_metrics( + prompt_tokens, completion_tokens, total_tokens, model + ) if self.telemetry: for metric in metrics: await self.telemetry.log_event(metric) - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] + return [ + MetricInResponse(metric=metric.metric, value=metric.value) + for metric in metrics + ] async def _count_tokens( self, @@ -236,7 +253,9 @@ class InferenceRouter(Inference): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, - ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: + ) -> Union[ + ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] + ]: logger.debug( f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", ) @@ -246,12 +265,19 @@ class InferenceRouter(Inference): if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: - raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") + raise ValueError( + f"Model '{model_id}' is an embedding model and does not support chat completions" + ) if tool_config: if tool_choice and tool_choice != tool_config.tool_choice: raise ValueError("tool_choice and tool_config.tool_choice must match") - if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format: - raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match") + if ( + tool_prompt_format + and tool_prompt_format != tool_config.tool_prompt_format + ): + raise ValueError( + "tool_prompt_format and tool_config.tool_prompt_format must match" + ) else: params = {} if tool_choice: @@ -269,9 +295,14 @@ class InferenceRouter(Inference): pass else: # verify tool_choice is one of the tools - tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] + tool_names = [ + t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value + for t in tools + ] if tool_config.tool_choice not in tool_names: - raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") + raise ValueError( + f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}" + ) params = dict( model_id=model_id, @@ -286,19 +317,32 @@ class InferenceRouter(Inference): tool_config=tool_config, ) provider = self.routing_table.get_provider_impl(model_id) - prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) + prompt_tokens = await self._count_tokens( + messages, tool_config.tool_prompt_format + ) if stream: async def stream_generator(): completion_text = "" async for chunk in await provider.chat_completion(**params): - if chunk.event.event_type == ChatCompletionResponseEventType.progress: + if ( + chunk.event.event_type + == ChatCompletionResponseEventType.progress + ): if chunk.event.delta.type == "text": completion_text += chunk.event.delta.text - if chunk.event.event_type == ChatCompletionResponseEventType.complete: + if ( + chunk.event.event_type + == ChatCompletionResponseEventType.complete + ): completion_tokens = await self._count_tokens( - [CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)], + [ + CompletionMessage( + content=completion_text, + stop_reason=StopReason.end_of_turn, + ) + ], tool_config.tool_prompt_format, ) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) @@ -308,7 +352,11 @@ class InferenceRouter(Inference): total_tokens, model, ) - chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + chunk.metrics = ( + metrics + if chunk.metrics is None + else chunk.metrics + metrics + ) yield chunk return stream_generator() @@ -325,7 +373,9 @@ class InferenceRouter(Inference): total_tokens, model, ) - response.metrics = metrics if response.metrics is None else response.metrics + metrics + response.metrics = ( + metrics if response.metrics is None else response.metrics + metrics + ) return response async def completion( @@ -346,7 +396,9 @@ class InferenceRouter(Inference): if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: - raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") + raise ValueError( + f"Model '{model_id}' is an embedding model and does not support chat completions" + ) provider = self.routing_table.get_provider_impl(model_id) params = dict( model_id=model_id, @@ -366,7 +418,11 @@ class InferenceRouter(Inference): async for chunk in await provider.completion(**params): if hasattr(chunk, "delta"): completion_text += chunk.delta - if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: + if ( + hasattr(chunk, "stop_reason") + and chunk.stop_reason + and self.telemetry + ): completion_tokens = await self._count_tokens(completion_text) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) metrics = await self._compute_and_log_token_usage( @@ -375,7 +431,11 @@ class InferenceRouter(Inference): total_tokens, model, ) - chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + chunk.metrics = ( + metrics + if chunk.metrics is None + else chunk.metrics + metrics + ) yield chunk return stream_generator() @@ -389,7 +449,9 @@ class InferenceRouter(Inference): total_tokens, model, ) - response.metrics = metrics if response.metrics is None else response.metrics + metrics + response.metrics = ( + metrics if response.metrics is None else response.metrics + metrics + ) return response async def embeddings( @@ -405,7 +467,9 @@ class InferenceRouter(Inference): if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.llm: - raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings") + raise ValueError( + f"Model '{model_id}' is an LLM model and does not support embeddings" + ) return await self.routing_table.get_provider_impl(model_id).embeddings( model_id=model_id, contents=contents, @@ -439,7 +503,9 @@ class SafetyRouter(Safety): params: Optional[Dict[str, Any]] = None, ) -> Shield: logger.debug(f"SafetyRouter.register_shield: {shield_id}") - return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) + return await self.routing_table.register_shield( + shield_id, provider_shield_id, provider_id, params + ) async def run_shield( self, @@ -477,11 +543,13 @@ class DatasetIORouter(DatasetIO): rows_in_page: int, page_token: Optional[str] = None, filter_condition: Optional[str] = None, - ) -> PaginatedRowsResult: + ) -> IterrowsResponse: logger.debug( f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}", ) - return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated( + return await self.routing_table.get_provider_impl( + dataset_id + ).get_rows_paginated( dataset_id=dataset_id, rows_in_page=rows_in_page, page_token=page_token, @@ -521,7 +589,9 @@ class ScoringRouter(Scoring): logger.debug(f"ScoringRouter.score_batch: {dataset_id}") res = {} for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( + score_response = await self.routing_table.get_provider_impl( + fn_identifier + ).score_batch( dataset_id=dataset_id, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) @@ -539,11 +609,15 @@ class ScoringRouter(Scoring): input_rows: List[Dict[str, Any]], scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, ) -> ScoreResponse: - logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") + logger.debug( + f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions" + ) res = {} # look up and map each scoring function to its provider impl for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score( + score_response = await self.routing_table.get_provider_impl( + fn_identifier + ).score( input_rows=input_rows, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) @@ -586,7 +660,9 @@ class EvalRouter(Eval): scoring_functions: List[str], benchmark_config: BenchmarkConfig, ) -> EvaluateResponse: - logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") + logger.debug( + f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows" + ) return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( benchmark_id=benchmark_id, input_rows=input_rows, @@ -600,7 +676,9 @@ class EvalRouter(Eval): job_id: str, ) -> Optional[JobStatus]: logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) + return await self.routing_table.get_provider_impl(benchmark_id).job_status( + benchmark_id, job_id + ) async def job_cancel( self, @@ -654,9 +732,9 @@ class ToolRuntimeRouter(ToolRuntime): logger.debug( f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" ) - return await self.routing_table.get_provider_impl("insert_into_memory").insert( - documents, vector_db_id, chunk_size_in_tokens - ) + return await self.routing_table.get_provider_impl( + "insert_into_memory" + ).insert(documents, vector_db_id, chunk_size_in_tokens) def __init__( self, @@ -689,4 +767,6 @@ class ToolRuntimeRouter(ToolRuntime): self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") - return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) + return await self.routing_table.get_provider_impl(tool_group_id).list_tools( + tool_group_id, mcp_endpoint + ) diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index c5216e026..03dbae337 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -13,7 +13,7 @@ from urllib.parse import urlparse import pandas from llama_stack.apis.common.content_types import URL -from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.datasets import Dataset from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url @@ -134,7 +134,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): rows_in_page: int, page_token: Optional[str] = None, filter_condition: Optional[str] = None, - ) -> PaginatedRowsResult: + ) -> IterrowsResponse: dataset_info = self.dataset_infos.get(dataset_id) dataset_info.dataset_impl.load() @@ -154,7 +154,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): rows = dataset_info.dataset_impl[start:end] - return PaginatedRowsResult( + return IterrowsResponse( rows=rows, total_count=len(rows), next_page_token=str(end), @@ -170,7 +170,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): new_rows_df = pandas.DataFrame(rows) new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df) - dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True) + dataset_impl.df = pandas.concat( + [dataset_impl.df, new_rows_df], ignore_index=True + ) url = str(dataset_info.dataset_def.url.uri) parsed_url = urlparse(url) @@ -185,8 +187,12 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): raise ValueError("Data URL must be a base64-encoded CSV") csv_buffer = dataset_impl.df.to_csv(index=False) - base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8") - dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}") + base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode( + "utf-8" + ) + dataset_info.dataset_def.url = URL( + uri=f"data:text/csv;base64,{base64_content}" + ) else: raise ValueError( f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing." diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index cd4e7f1f1..8df64a190 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional import datasets as hf_datasets -from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult +from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.datasets import Dataset from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url @@ -79,7 +79,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): rows_in_page: int, page_token: Optional[str] = None, filter_condition: Optional[str] = None, - ) -> PaginatedRowsResult: + ) -> IterrowsResponse: dataset_def = self.dataset_infos[dataset_id] loaded_dataset = load_hf_dataset(dataset_def) @@ -99,7 +99,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): rows = [loaded_dataset[i] for i in range(start, end)] - return PaginatedRowsResult( + return IterrowsResponse( rows=rows, total_count=len(rows), next_page_token=str(end), @@ -113,9 +113,13 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): new_dataset = hf_datasets.Dataset.from_list(rows) # Concatenate the new rows with existing dataset - updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset]) + updated_dataset = hf_datasets.concatenate_datasets( + [loaded_dataset, new_dataset] + ) if dataset_def.metadata.get("path", None): updated_dataset.push_to_hub(dataset_def.metadata["path"]) else: - raise NotImplementedError("Uploading to URL-based datasets is not supported yet") + raise NotImplementedError( + "Uploading to URL-based datasets is not supported yet" + )