diff --git a/llama_stack/providers/remote/inference/nvidia/__init__.py b/llama_stack/providers/remote/inference/nvidia/__init__.py index 63b466933..99b37a823 100644 --- a/llama_stack/providers/remote/inference/nvidia/__init__.py +++ b/llama_stack/providers/remote/inference/nvidia/__init__.py @@ -4,11 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.inference import Inference + from ._config import NVIDIAConfig -from ._nvidia import NVIDIAInferenceAdapter -async def get_adapter_impl(config: NVIDIAConfig, _deps) -> NVIDIAInferenceAdapter: +async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference: + # import dynamically so `llama stack build` does not fail due to missing dependencies + from ._nvidia import NVIDIAInferenceAdapter + if not isinstance(config, NVIDIAConfig): raise RuntimeError(f"Unexpected config type: {type(config)}") adapter = NVIDIAInferenceAdapter(config)