From cf225c9710e202331814c4a4f783939b2db7b2c5 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sat, 15 Mar 2025 16:20:58 -0700
Subject: [PATCH] precommit

---
 llama_stack/apis/datasetio/datasetio.py      |   4 +-
 llama_stack/distribution/routers/routers.py  | 135 +++++-------
 .../distribution/routers/routing_tables.py   |  52 ++-----
 .../datasetio/huggingface/huggingface.py     |  11 +-
 tests/integration/datasets/test_datasets.py  |   8 +-
 tests/integration/datasets/test_script.py    |   6 +
 6 files changed, 56 insertions(+), 160 deletions(-)

diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py
index 8545a7189..6079e5b99 100644
--- a/llama_stack/apis/datasetio/datasetio.py
+++ b/llama_stack/apis/datasetio/datasetio.py
@@ -51,6 +51,4 @@ class DatasetIO(Protocol):
         ...
 
     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
-    async def append_rows(
-        self, dataset_id: str, rows: List[Dict[str, Any]]
-    ) -> None: ...
+    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 879fc924b..2cf38f544 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -8,9 +8,9 @@ import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from llama_stack.apis.common.content_types import (
+    URL,
     InterleavedContent,
     InterleavedContentItem,
-    URL,
 )
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
@@ -94,9 +94,7 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logger.debug(
-            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
-        )
+        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -114,9 +112,7 @@ class VectorIORouter(VectorIO):
         logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
-        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
-            vector_db_id, chunks, ttl_seconds
-        )
+        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
 
     async def query_chunks(
         self,
         vector_db_id: str,
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
-            vector_db_id, query, params
-        )
+        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
 
 
 class InferenceRouter(Inference):
@@ -164,9 +158,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, metadata, model_type
-        )
+        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
 
     def _construct_metrics(
         self,
@@ -220,16 +212,11 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> List[MetricInResponse]:
-        metrics = self._construct_metrics(
-            prompt_tokens, completion_tokens, total_tokens, model
-        )
+        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
-        return [
-            MetricInResponse(metric=metric.metric, value=metric.value)
-            for metric in metrics
-        ]
+        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
 
     async def _count_tokens(
         self,
@@ -254,9 +241,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
@@ -266,19 +251,12 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
+            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if (
-                tool_prompt_format
-                and tool_prompt_format != tool_config.tool_prompt_format
-            ):
-                raise ValueError(
-                    "tool_prompt_format and tool_config.tool_prompt_format must match"
-                )
+            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
+                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
         else:
             params = {}
             if tool_choice:
@@ -296,14 +274,9 @@ class InferenceRouter(Inference):
                 pass
             else:
                 # verify tool_choice is one of the tools
-                tool_names = [
-                    t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
-                    for t in tools
-                ]
+                tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
                 if tool_config.tool_choice not in tool_names:
-                    raise ValueError(
-                        f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
-                    )
+                    raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
 
         params = dict(
             model_id=model_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools,
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(
-            messages, tool_config.tool_prompt_format
-        )
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
 
         if stream:
 
             async def stream_generator():
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
-                    if (
-                        chunk.event.event_type
-                        == ChatCompletionResponseEventType.progress
-                    ):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                         if chunk.event.delta.type == "text":
                             completion_text += chunk.event.delta.text
-                    if (
-                        chunk.event.event_type
-                        == ChatCompletionResponseEventType.complete
-                    ):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
                         completion_tokens = await self._count_tokens(
                             [
                                 CompletionMessage(
                                     content=completion_text,
                                     stop_reason=StopReason.end_of_turn,
                                 )
                             ],
                             tool_config.tool_prompt_format,
                         )
                         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                         metrics = await self._compute_and_log_token_usage(
                             prompt_tokens or 0,
                             completion_tokens or 0,
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = (
-                            metrics
-                            if chunk.metrics is None
-                            else chunk.metrics + metrics
-                        )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                         yield chunk
 
             return stream_generator()
@@ -374,9 +335,7 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = (
-            metrics if response.metrics is None else response.metrics + metrics
-        )
+        response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response
 
     async def completion(
@@ -397,9 +356,7 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
+            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -419,11 +376,7 @@ class InferenceRouter(Inference):
             async for chunk in await provider.completion(**params):
                 if hasattr(chunk, "delta"):
                     completion_text += chunk.delta
-                if (
-                    hasattr(chunk, "stop_reason")
-                    and chunk.stop_reason
-                    and self.telemetry
-                ):
+                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                     completion_tokens = await self._count_tokens(completion_text)
                     total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                     metrics = await self._compute_and_log_token_usage(
@@ -432,11 +385,7 @@ class InferenceRouter(Inference):
                         total_tokens,
                         model,
                     )
-                    chunk.metrics = (
-                        metrics
-                        if chunk.metrics is None
-                        else chunk.metrics + metrics
-                    )
+                    chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk
 
             return stream_generator()
@@ -450,9 +399,7 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = (
-            metrics if response.metrics is None else response.metrics + metrics
-        )
+        response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response
 
     async def embeddings(
@@ -468,9 +415,7 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.llm:
-            raise ValueError(
-                f"Model '{model_id}' is an LLM model and does not support embeddings"
-            )
+            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,
@@ -504,9 +449,7 @@ class SafetyRouter(Safety):
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
         logger.debug(f"SafetyRouter.register_shield: {shield_id}")
-        return await self.routing_table.register_shield(
-            shield_id, provider_shield_id, provider_id, params
-        )
+        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
 
     async def run_shield(
         self,
@@ -603,9 +546,7 @@ class ScoringRouter(Scoring):
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(
-                fn_identifier
-            ).score_batch(
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
                 dataset_id=dataset_id,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -623,15 +564,11 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logger.debug(
-            f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions"
-        )
+        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(
-                fn_identifier
-            ).score(
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
                 input_rows=input_rows,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -674,9 +611,7 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
-        logger.debug(
-            f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows"
-        )
+        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -690,9 +625,7 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> Optional[JobStatus]:
         logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(
-            benchmark_id, job_id
-        )
+        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
 
     async def job_cancel(
         self,
@@ -746,9 +679,9 @@ class ToolRuntimeRouter(ToolRuntime):
             logger.debug(
                 f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
             )
-            return await self.routing_table.get_provider_impl(
-                "insert_into_memory"
-            ).insert(documents, vector_db_id, chunk_size_in_tokens)
+            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
+                documents, vector_db_id, chunk_size_in_tokens
+            )
 
     def __init__(
         self,
@@ -781,6 +714,4 @@ class ToolRuntimeRouter(ToolRuntime):
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
-        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
-            tool_group_id, mcp_endpoint
-        )
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index df800a6e0..533993421 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -105,9 +105,7 @@ class CommonRoutingTableImpl(RoutingTable):
         self.dist_registry = dist_registry
 
     async def initialize(self) -> None:
-        async def add_objects(
-            objs: List[RoutableObjectWithProvider], provider_id: str, cls
-        ) -> None:
+        async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
             for obj in objs:
                 if cls is None:
                     obj.provider_id = provider_id
@@ -142,9 +140,7 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.impls_by_provider_id.values():
             await p.shutdown()
 
-    def get_provider_impl(
-        self, routing_key: str, provider_id: Optional[str] = None
-    ) -> Any:
+    def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
         def apiname_object():
             if isinstance(self, ModelsRoutingTable):
                 return ("Inference", "model")
@@ -182,9 +178,7 @@ class CommonRoutingTableImpl(RoutingTable):
 
         raise ValueError(f"Provider not found for `{routing_key}`")
 
-    async def get_object_by_identifier(
-        self, type: str, identifier: str
-    ) -> Optional[RoutableObjectWithProvider]:
+    async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
         # Get from disk registry
         obj = await self.dist_registry.get(type, identifier)
         if not obj:
@@ -194,13 +188,9 @@ class CommonRoutingTableImpl(RoutingTable):
 
     async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
         await self.dist_registry.delete(obj.type, obj.identifier)
-        await unregister_object_from_provider(
-            obj, self.impls_by_provider_id[obj.provider_id]
-        )
+        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
 
-    async def register_object(
-        self, obj: RoutableObjectWithProvider
-    ) -> RoutableObjectWithProvider:
+    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
         # if provider_id is not specified, pick an arbitrary one from existing entries
         if not obj.provider_id and len(self.impls_by_provider_id) > 0:
             obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -255,9 +245,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         if model_type is None:
             model_type = ModelType.llm
         if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
-            raise ValueError(
-                "Embedding model must have an embedding dimension in its metadata"
-            )
+            raise ValueError("Embedding model must have an embedding dimension in its metadata")
         model = Model(
             identifier=model_id,
             provider_resource_id=provider_model_id,
@@ -277,9 +265,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
 
 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
     async def list_shields(self) -> ListShieldsResponse:
-        return ListShieldsResponse(
-            data=await self.get_all_with_type(ResourceType.shield.value)
-        )
+        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
 
     async def get_shield(self, identifier: str) -> Optional[Shield]:
         return await self.get_object_by_identifier("shield", identifier)
@@ -338,18 +324,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
                 f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
             )
         else:
-            raise ValueError(
-                "No provider available. Please configure a vector_io provider."
-            )
+            raise ValueError("No provider available. Please configure a vector_io provider.")
         model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
             raise ValueError(f"Model {embedding_model} not found")
         if model.model_type != ModelType.embedding:
             raise ValueError(f"Model {embedding_model} is not an embedding model")
         if "embedding_dimension" not in model.metadata:
-            raise ValueError(
-                f"Model {embedding_model} does not have an embedding dimension"
-            )
+            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
         vector_db_data = {
             "identifier": vector_db_id,
             "type": ResourceType.vector_db.value,
@@ -371,9 +353,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
     async def list_datasets(self) -> ListDatasetsResponse:
-        return ListDatasetsResponse(
-            data=await self.get_all_with_type(ResourceType.dataset.value)
-        )
+        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
 
     async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
         return await self.get_object_by_identifier("dataset", dataset_id)
@@ -426,9 +406,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
 
 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
     async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
-        return ListScoringFunctionsResponse(
-            data=await self.get_all_with_type(ResourceType.scoring_function.value)
-        )
+        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
 
     async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
         return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@@ -525,12 +503,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         args: Optional[Dict[str, Any]] = None,
     ) -> None:
         tools = []
-        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
-            toolgroup_id, mcp_endpoint
-        )
-        tool_host = (
-            ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
-        )
+        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
+        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
 
         for tool_def in tool_defs:
             tools.append(
diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
index 82a76f8bc..23a6cffb5 100644
--- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
+++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from typing import Any, Dict, List, Optional
-
 from urllib.parse import parse_qs, urlparse
 
 import datasets as hf_datasets
@@ -12,13 +11,11 @@ import datasets as hf_datasets
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
-from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from .config import HuggingfaceDatasetIOConfig
 
 DATASETS_PREFIX = "datasets:"
-from rich.pretty import pprint
 
 
 def parse_hf_params(dataset_def: Dataset):
@@ -102,13 +99,9 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         new_dataset = hf_datasets.Dataset.from_list(rows)
 
         # Concatenate the new rows with existing dataset
-        updated_dataset = hf_datasets.concatenate_datasets(
-            [loaded_dataset, new_dataset]
-        )
+        updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])
 
         if dataset_def.metadata.get("path", None):
             updated_dataset.push_to_hub(dataset_def.metadata["path"])
         else:
-            raise NotImplementedError(
-                "Uploading to URL-based datasets is not supported yet"
-            )
+            raise NotImplementedError("Uploading to URL-based datasets is not supported yet")
diff --git a/tests/integration/datasets/test_datasets.py b/tests/integration/datasets/test_datasets.py
index be02705bc..fdae5420c 100644
--- a/tests/integration/datasets/test_datasets.py
+++ b/tests/integration/datasets/test_datasets.py
@@ -4,10 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import base64
-import mimetypes
-import os
-from pathlib import Path
 
 import pytest
 
@@ -36,8 +32,6 @@ def test_register_dataset(llama_stack_client, purpose, source, provider_id):
     )
     assert dataset.identifier is not None
     assert dataset.provider_id == provider_id
-    iterrow_response = llama_stack_client.datasets.iterrows(
-        dataset.identifier, limit=10
-    )
+    iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=10)
     assert len(iterrow_response.data) == 10
     assert iterrow_response.next_index is not None
diff --git a/tests/integration/datasets/test_script.py b/tests/integration/datasets/test_script.py
index afdde25ca..99bd27460 100644
--- a/tests/integration/datasets/test_script.py
+++ b/tests/integration/datasets/test_script.py
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 from llama_stack_client import LlamaStackClient
 from rich.pretty import pprint