diff --git a/docs/source/distributions/importing_as_library.md b/docs/source/distributions/importing_as_library.md index fbc48dd95..b9b4b065a 100644 --- a/docs/source/distributions/importing_as_library.md +++ b/docs/source/distributions/importing_as_library.md @@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient( # provider_data is optional, but if you need to pass in any provider specific data, you can do so here. provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]}, ) -client.initialize() ``` This will parse your config and set up any inline implementations and remote clients needed for your implementation. @@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/ ```python client = LlamaStackAsLibraryClient(config_path) -client.initialize() ``` diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index a93fe509e..5561bdaef 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -145,7 +145,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): ): super().__init__() self.async_client = AsyncLlamaStackAsLibraryClient( - config_path_or_distro_name, custom_provider_registry, provider_data + config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal ) self.pool_executor = ThreadPoolExecutor(max_workers=4) self.skip_logger_removal = skip_logger_removal @@ -153,31 +153,19 @@ class LlamaStackAsLibraryClient(LlamaStackClient): self.loop = asyncio.new_event_loop() - def initialize(self): - if in_notebook(): - import nest_asyncio - - nest_asyncio.apply() - if not self.skip_logger_removal: - self._remove_root_logger_handlers() - # use a new event loop to avoid interfering with the main event loop loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - return loop.run_until_complete(self.async_client.initialize()) + loop.run_until_complete(self.async_client.initialize()) finally: asyncio.set_event_loop(None) - def _remove_root_logger_handlers(self): + def initialize(self): """ - Remove all handlers from the root logger. Needed to avoid polluting the console with logs. + Deprecated method for backward compatibility. """ - root_logger = logging.getLogger() - - for handler in root_logger.handlers[:]: - root_logger.removeHandler(handler) - logger.info(f"Removed handler {handler.__class__.__name__} from root logger") + pass def request(self, *args, **kwargs): loop = self.loop @@ -215,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): config_path_or_distro_name: str, custom_provider_registry: ProviderRegistry | None = None, provider_data: dict[str, Any] | None = None, + skip_logger_removal: bool = False, ): super().__init__() # when using the library client, we should not log to console since many @@ -222,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",") os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console") + if in_notebook(): + import nest_asyncio + + nest_asyncio.apply() + if not skip_logger_removal: + self._remove_root_logger_handlers() + if config_path_or_distro_name.endswith(".yaml"): config_path = Path(config_path_or_distro_name) if not config_path.exists(): @@ -238,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): self.provider_data = provider_data self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError + def _remove_root_logger_handlers(self): + """ + Remove all handlers from the root logger. Needed to avoid polluting the console with logs. + """ + root_logger = logging.getLogger() + + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + logger.info(f"Removed handler {handler.__class__.__name__} from root logger") + async def initialize(self) -> bool: + """ + Initialize the async client. Can be called multiple times safely. + + Returns: + bool: True if initialization was successful + """ + try: self.route_impls = None self.impls = await construct_stack(self.config, self.custom_provider_registry) @@ -298,9 +311,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): stream=False, stream_cls=None, ): - if self.route_impls is None: - raise ValueError("Client not initialized. Please call initialize() first.") - # Create headers with provider data if available headers = options.headers or {} if self.provider_data: diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index 0b7132d71..703cd7721 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -256,9 +256,7 @@ def instantiate_llama_stack_client(session): provider_data=get_provider_data(), skip_logger_removal=True, ) - if not client.initialize(): - raise RuntimeError("Initialization failed") - + # Client is automatically initialized during construction return client diff --git a/tests/integration/non_ci/responses/fixtures/fixtures.py b/tests/integration/non_ci/responses/fixtures/fixtures.py index 2069010ad..377e2cecd 100644 --- a/tests/integration/non_ci/responses/fixtures/fixtures.py +++ b/tests/integration/non_ci/responses/fixtures/fixtures.py @@ -127,8 +127,7 @@ def openai_client(base_url, api_key, provider): raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:'") config = parts[1] client = LlamaStackAsLibraryClient(config, skip_logger_removal=True) - if not client.initialize(): - raise RuntimeError("Initialization failed") + # Client is automatically initialized during construction return client return OpenAI( diff --git a/tests/unit/distribution/test_library_client_initialization.py b/tests/unit/distribution/test_library_client_initialization.py index e510d513d..2108b4676 100644 --- a/tests/unit/distribution/test_library_client_initialization.py +++ b/tests/unit/distribution/test_library_client_initialization.py @@ -5,86 +5,70 @@ # the root directory of this source tree. """ -Unit tests for LlamaStackAsLibraryClient initialization error handling. +Unit tests for LlamaStackAsLibraryClient automatic initialization. -These tests ensure that users get proper error messages when they forget to call -initialize() on the library client, preventing AttributeError regressions. +These tests ensure that the library client is automatically initialized +and ready to use immediately after construction. """ -import pytest - from llama_stack.core.library_client import ( AsyncLlamaStackAsLibraryClient, LlamaStackAsLibraryClient, ) -class TestLlamaStackAsLibraryClientInitialization: - """Test proper error handling for uninitialized library clients.""" +class TestLlamaStackAsLibraryClientAutoInitialization: + """Test automatic initialization of library clients.""" - @pytest.mark.parametrize( - "api_call", - [ - lambda client: client.models.list(), - lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]), - lambda client: next( - client.chat.completions.create( - model="test", messages=[{"role": "user", "content": "test"}], stream=True - ) - ), - ], - ids=["models.list", "chat.completions.create", "chat.completions.create_stream"], - ) - def test_sync_client_proper_error_without_initialization(self, api_call): - """Test that sync client raises ValueError with helpful message when not initialized.""" + def test_sync_client_auto_initialization(self): + """Test that sync client is automatically initialized after construction.""" client = LlamaStackAsLibraryClient("nvidia") - with pytest.raises(ValueError) as exc_info: - api_call(client) + # Client should be automatically initialized + assert client.async_client._is_initialized is True + assert client.async_client.route_impls is not None - error_msg = str(exc_info.value) - assert "Client not initialized" in error_msg - assert "Please call initialize() first" in error_msg - - @pytest.mark.parametrize( - "api_call", - [ - lambda client: client.models.list(), - lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]), - ], - ids=["models.list", "chat.completions.create"], - ) - async def test_async_client_proper_error_without_initialization(self, api_call): - """Test that async client raises ValueError with helpful message when not initialized.""" + async def test_async_client_auto_initialization(self): + """Test that async client can be initialized and works properly.""" client = AsyncLlamaStackAsLibraryClient("nvidia") - with pytest.raises(ValueError) as exc_info: - await api_call(client) + # Initialize the client + result = await client.initialize() + assert result is True + assert client._is_initialized is True + assert client.route_impls is not None - error_msg = str(exc_info.value) - assert "Client not initialized" in error_msg - assert "Please call initialize() first" in error_msg + def test_initialize_method_backward_compatibility(self): + """Test that initialize() method still works for backward compatibility.""" + client = LlamaStackAsLibraryClient("nvidia") - async def test_async_client_streaming_error_without_initialization(self): - """Test that async client streaming raises ValueError with helpful message when not initialized.""" + # initialize() should return None (historical behavior) and not cause errors + result = client.initialize() + assert result is None + + # Multiple calls should be safe + result2 = client.initialize() + assert result2 is None + + async def test_async_initialize_method_idempotent(self): + """Test that async initialize() method can be called multiple times safely.""" client = AsyncLlamaStackAsLibraryClient("nvidia") - with pytest.raises(ValueError) as exc_info: - stream = await client.chat.completions.create( - model="test", messages=[{"role": "user", "content": "test"}], stream=True - ) - await anext(stream) + # First initialization + result1 = await client.initialize() + assert result1 is True + assert client._is_initialized is True - error_msg = str(exc_info.value) - assert "Client not initialized" in error_msg - assert "Please call initialize() first" in error_msg + # Second initialization should be safe and return True + result2 = await client.initialize() + assert result2 is True + assert client._is_initialized is True - def test_route_impls_initialized_to_none(self): - """Test that route_impls is initialized to None to prevent AttributeError.""" - # Test sync client + def test_route_impls_automatically_set(self): + """Test that route_impls is automatically set during construction.""" + # Test sync client - should be auto-initialized sync_client = LlamaStackAsLibraryClient("nvidia") - assert sync_client.async_client.route_impls is None + assert sync_client.async_client.route_impls is not None - # Test async client directly - async_client = AsyncLlamaStackAsLibraryClient("nvidia") - assert async_client.route_impls is None + # Test that the async client is marked as initialized + assert sync_client.async_client._is_initialized is True