diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 589a03b25..df800a6e0 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -105,7 +105,9 @@ class CommonRoutingTableImpl(RoutingTable): self.dist_registry = dist_registry async def initialize(self) -> None: - async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None: + async def add_objects( + objs: List[RoutableObjectWithProvider], provider_id: str, cls + ) -> None: for obj in objs: if cls is None: obj.provider_id = provider_id @@ -140,7 +142,9 @@ class CommonRoutingTableImpl(RoutingTable): for p in self.impls_by_provider_id.values(): await p.shutdown() - def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any: + def get_provider_impl( + self, routing_key: str, provider_id: Optional[str] = None + ) -> Any: def apiname_object(): if isinstance(self, ModelsRoutingTable): return ("Inference", "model") @@ -178,7 +182,9 @@ class CommonRoutingTableImpl(RoutingTable): raise ValueError(f"Provider not found for `{routing_key}`") - async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: + async def get_object_by_identifier( + self, type: str, identifier: str + ) -> Optional[RoutableObjectWithProvider]: # Get from disk registry obj = await self.dist_registry.get(type, identifier) if not obj: @@ -188,9 +194,13 @@ class CommonRoutingTableImpl(RoutingTable): async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: await self.dist_registry.delete(obj.type, obj.identifier) - await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) + await unregister_object_from_provider( + obj, self.impls_by_provider_id[obj.provider_id] + ) - async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: + async def register_object( + self, obj: RoutableObjectWithProvider + ) -> RoutableObjectWithProvider: # if provider_id is not specified, pick an arbitrary one from existing entries if not obj.provider_id and len(self.impls_by_provider_id) > 0: obj.provider_id = list(self.impls_by_provider_id.keys())[0] @@ -245,7 +255,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): if model_type is None: model_type = ModelType.llm if "embedding_dimension" not in metadata and model_type == ModelType.embedding: - raise ValueError("Embedding model must have an embedding dimension in its metadata") + raise ValueError( + "Embedding model must have an embedding dimension in its metadata" + ) model = Model( identifier=model_id, provider_resource_id=provider_model_id, @@ -265,7 +277,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> ListShieldsResponse: - return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) + return ListShieldsResponse( + data=await self.get_all_with_type(ResourceType.shield.value) + ) async def get_shield(self, identifier: str) -> Optional[Shield]: return await self.get_object_by_identifier("shield", identifier) @@ -324,14 +338,18 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." ) else: - raise ValueError("No provider available. Please configure a vector_io provider.") + raise ValueError( + "No provider available. Please configure a vector_io provider." + ) model = await self.get_object_by_identifier("model", embedding_model) if model is None: raise ValueError(f"Model {embedding_model} not found") if model.model_type != ModelType.embedding: raise ValueError(f"Model {embedding_model} is not an embedding model") if "embedding_dimension" not in model.metadata: - raise ValueError(f"Model {embedding_model} does not have an embedding dimension") + raise ValueError( + f"Model {embedding_model} does not have an embedding dimension" + ) vector_db_data = { "identifier": vector_db_id, "type": ResourceType.vector_db.value, @@ -353,7 +371,9 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> ListDatasetsResponse: - return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) + return ListDatasetsResponse( + data=await self.get_all_with_type(ResourceType.dataset.value) + ) async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: return await self.get_object_by_identifier("dataset", dataset_id) @@ -371,9 +391,9 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): provider_dataset_id = dataset_id # infer provider from source - if source.type == DatasetType.rows: + if source.type == DatasetType.rows.value: provider_id = "localfs" - elif source.type == DatasetType.uri: + elif source.type == DatasetType.uri.value: # infer provider from uri if source.uri.startswith("huggingface"): provider_id = "huggingface" @@ -406,7 +426,9 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> ListScoringFunctionsResponse: - return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) + return ListScoringFunctionsResponse( + data=await self.get_all_with_type(ResourceType.scoring_function.value) + ) async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: return await self.get_object_by_identifier("scoring_function", scoring_fn_id) @@ -503,8 +525,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): args: Optional[Dict[str, Any]] = None, ) -> None: tools = [] - tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) - tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution + tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools( + toolgroup_id, mcp_endpoint + ) + tool_host = ( + ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution + ) for tool_def in tool_defs: tools.append( diff --git a/tests/integration/datasets/test_datasets.py b/tests/integration/datasets/test_datasets.py new file mode 100644 index 000000000..38abb54c9 --- /dev/null +++ b/tests/integration/datasets/test_datasets.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import base64 +import mimetypes +import os +from pathlib import Path + +import pytest + +# How to run this test: +# +# LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasets + + +def test_register_dataset(llama_stack_client): + dataset = llama_stack_client.datasets.register( + purpose="eval/messages-answer", + source={"type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"}, + ) + print(dataset) diff --git a/tests/integration/datasets/test_script.py b/tests/integration/datasets/test_script.py new file mode 100644 index 000000000..a3f5f626b --- /dev/null +++ b/tests/integration/datasets/test_script.py @@ -0,0 +1,15 @@ +from llama_stack_client import LlamaStackClient +from rich.pretty import pprint + + +def test_register_dataset(): + client = LlamaStackClient(base_url="http://localhost:8321") + dataset = client.datasets.register( + purpose="eval/messages-answer", + source={"type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"}, + ) + pprint(dataset) + + +if __name__ == "__main__": + test_register_dataset()