From fdfc37a8788f0d4ad373524384752254084a082a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 11 Nov 2024 12:02:17 -0500 Subject: [PATCH] huggingface -> remote adapter --- .../adapters/datasetio/huggingface/__init__.py | 2 +- llama_stack/providers/registry/datasetio.py | 15 +++++++++------ llama_stack/providers/tests/datasetio/fixtures.py | 2 +- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/llama_stack/providers/adapters/datasetio/huggingface/__init__.py b/llama_stack/providers/adapters/datasetio/huggingface/__init__.py index cda67177c..db803d183 100644 --- a/llama_stack/providers/adapters/datasetio/huggingface/__init__.py +++ b/llama_stack/providers/adapters/datasetio/huggingface/__init__.py @@ -7,7 +7,7 @@ from .config import HuggingfaceDatasetIOConfig -async def get_provider_impl( +async def get_adapter_impl( config: HuggingfaceDatasetIOConfig, _deps, ): diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py index 82f90e022..3fdeac997 100644 --- a/llama_stack/providers/registry/datasetio.py +++ b/llama_stack/providers/registry/datasetio.py @@ -19,12 +19,15 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.inline.meta_reference.datasetio.MetaReferenceDatasetIOConfig", api_dependencies=[], ), - InlineProviderSpec( + remote_provider_spec( api=Api.datasetio, - provider_type="huggingface", - pip_packages=["datasets"], - module="llama_stack.providers.adapters.datasetio.huggingface", - config_class="llama_stack.providers.adapters.datasetio.huggingface.HuggingfaceDatasetIOConfig", - api_dependencies=[], + adapter=AdapterSpec( + adapter_type="huggingface", + pip_packages=[ + "datasets", + ], + module="llama_stack.providers.adapters.datasetio.huggingface", + config_class="llama_stack.providers.adapters.datasetio.huggingface.HuggingfaceDatasetIOConfig", + ), ), ] diff --git a/llama_stack/providers/tests/datasetio/fixtures.py b/llama_stack/providers/tests/datasetio/fixtures.py index 7a389e4d1..d810d5e02 100644 --- a/llama_stack/providers/tests/datasetio/fixtures.py +++ b/llama_stack/providers/tests/datasetio/fixtures.py @@ -37,7 +37,7 @@ def datasetio_huggingface() -> ProviderFixture: providers=[ Provider( provider_id="huggingface", - provider_type="huggingface", + provider_type="remote::huggingface", config={}, ) ],