From 7d23fe3215eb69f1825ea2d3244750565664f56a Mon Sep 17 00:00:00 2001
From: Mustafa Elbehery
Date: Thu, 31 Jul 2025 14:18:45 +0200
Subject: [PATCH] refactor(client): remove initialize() method from
 LlamaStackAsLibraryClient

Currently, client.initialize() has to be invoked explicitly by the user.
To improve the developer experience and avoid runtime errors, this PR
initializes LlamaStackAsLibraryClient implicitly when the client is
constructed. It also prevents repeated initialization of the same client,
while maintaining backward compatibility.

Signed-off-by: Mustafa Elbehery
---
 .../distributions/importing_as_library.md          |   2 -
 llama_stack/core/library_client.py                 |  50 +++++----
 tests/integration/fixtures/common.py               |   4 +-
 .../non_ci/responses/fixtures/fixtures.py          |   3 +-
 .../test_library_client_initialization.py          | 104 ++++++++----------
 5 files changed, 76 insertions(+), 87 deletions(-)

diff --git a/docs/source/distributions/importing_as_library.md b/docs/source/distributions/importing_as_library.md
index fbc48dd95..b9b4b065a 100644
--- a/docs/source/distributions/importing_as_library.md
+++ b/docs/source/distributions/importing_as_library.md
@@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient(
     # provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
     provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
 )
-client.initialize()
 ```
 
 This will parse your config and set up any inline implementations and remote clients needed for your implementation.
@@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/
 
 ```python
 client = LlamaStackAsLibraryClient(config_path)
-client.initialize()
 ```
diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py
index a93fe509e..5561bdaef 100644
--- a/llama_stack/core/library_client.py
+++ b/llama_stack/core/library_client.py
@@ -145,7 +145,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
     ):
         super().__init__()
         self.async_client = AsyncLlamaStackAsLibraryClient(
-            config_path_or_distro_name, custom_provider_registry, provider_data
+            config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
         )
         self.pool_executor = ThreadPoolExecutor(max_workers=4)
         self.skip_logger_removal = skip_logger_removal
@@ -153,31 +153,19 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
 
         self.loop = asyncio.new_event_loop()
 
-    def initialize(self):
-        if in_notebook():
-            import nest_asyncio
-
-            nest_asyncio.apply()
-        if not self.skip_logger_removal:
-            self._remove_root_logger_handlers()
-
         # use a new event loop to avoid interfering with the main event loop
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
         try:
-            return loop.run_until_complete(self.async_client.initialize())
+            loop.run_until_complete(self.async_client.initialize())
         finally:
             asyncio.set_event_loop(None)
 
-    def _remove_root_logger_handlers(self):
+    def initialize(self):
         """
-        Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+        Deprecated method for backward compatibility.
""" - root_logger = logging.getLogger() - - for handler in root_logger.handlers[:]: - root_logger.removeHandler(handler) - logger.info(f"Removed handler {handler.__class__.__name__} from root logger") + pass def request(self, *args, **kwargs): loop = self.loop @@ -215,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): config_path_or_distro_name: str, custom_provider_registry: ProviderRegistry | None = None, provider_data: dict[str, Any] | None = None, + skip_logger_removal: bool = False, ): super().__init__() # when using the library client, we should not log to console since many @@ -222,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",") os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console") + if in_notebook(): + import nest_asyncio + + nest_asyncio.apply() + if not skip_logger_removal: + self._remove_root_logger_handlers() + if config_path_or_distro_name.endswith(".yaml"): config_path = Path(config_path_or_distro_name) if not config_path.exists(): @@ -238,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): self.provider_data = provider_data self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError + def _remove_root_logger_handlers(self): + """ + Remove all handlers from the root logger. Needed to avoid polluting the console with logs. + """ + root_logger = logging.getLogger() + + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + logger.info(f"Removed handler {handler.__class__.__name__} from root logger") + async def initialize(self) -> bool: + """ + Initialize the async client. Can be called multiple times safely. + + Returns: + bool: True if initialization was successful + """ + try: self.route_impls = None self.impls = await construct_stack(self.config, self.custom_provider_registry) @@ -298,9 +311,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): stream=False, stream_cls=None, ): - if self.route_impls is None: - raise ValueError("Client not initialized. 
-
         # Create headers with provider data if available
         headers = options.headers or {}
         if self.provider_data:
diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py
index 0b7132d71..703cd7721 100644
--- a/tests/integration/fixtures/common.py
+++ b/tests/integration/fixtures/common.py
@@ -256,9 +256,7 @@ def instantiate_llama_stack_client(session):
         provider_data=get_provider_data(),
         skip_logger_removal=True,
     )
-    if not client.initialize():
-        raise RuntimeError("Initialization failed")
-
+    # Client is automatically initialized during construction
     return client
 
 
diff --git a/tests/integration/non_ci/responses/fixtures/fixtures.py b/tests/integration/non_ci/responses/fixtures/fixtures.py
index 62c4ae086..9ec0d8ada 100644
--- a/tests/integration/non_ci/responses/fixtures/fixtures.py
+++ b/tests/integration/non_ci/responses/fixtures/fixtures.py
@@ -113,8 +113,7 @@ def openai_client(base_url, api_key, provider):
             raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:'")
         config = parts[1]
         client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
-        if not client.initialize():
-            raise RuntimeError("Initialization failed")
+        # Client is automatically initialized during construction
         return client
 
     return OpenAI(
diff --git a/tests/unit/distribution/test_library_client_initialization.py b/tests/unit/distribution/test_library_client_initialization.py
index e510d513d..2108b4676 100644
--- a/tests/unit/distribution/test_library_client_initialization.py
+++ b/tests/unit/distribution/test_library_client_initialization.py
@@ -5,86 +5,70 @@
 # the root directory of this source tree.
 
 """
-Unit tests for LlamaStackAsLibraryClient initialization error handling.
+Unit tests for LlamaStackAsLibraryClient automatic initialization.
 
-These tests ensure that users get proper error messages when they forget to call
-initialize() on the library client, preventing AttributeError regressions.
+These tests ensure that the library client is automatically initialized
+and ready to use immediately after construction.
""" -import pytest - from llama_stack.core.library_client import ( AsyncLlamaStackAsLibraryClient, LlamaStackAsLibraryClient, ) -class TestLlamaStackAsLibraryClientInitialization: - """Test proper error handling for uninitialized library clients.""" +class TestLlamaStackAsLibraryClientAutoInitialization: + """Test automatic initialization of library clients.""" - @pytest.mark.parametrize( - "api_call", - [ - lambda client: client.models.list(), - lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]), - lambda client: next( - client.chat.completions.create( - model="test", messages=[{"role": "user", "content": "test"}], stream=True - ) - ), - ], - ids=["models.list", "chat.completions.create", "chat.completions.create_stream"], - ) - def test_sync_client_proper_error_without_initialization(self, api_call): - """Test that sync client raises ValueError with helpful message when not initialized.""" + def test_sync_client_auto_initialization(self): + """Test that sync client is automatically initialized after construction.""" client = LlamaStackAsLibraryClient("nvidia") - with pytest.raises(ValueError) as exc_info: - api_call(client) + # Client should be automatically initialized + assert client.async_client._is_initialized is True + assert client.async_client.route_impls is not None - error_msg = str(exc_info.value) - assert "Client not initialized" in error_msg - assert "Please call initialize() first" in error_msg - - @pytest.mark.parametrize( - "api_call", - [ - lambda client: client.models.list(), - lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]), - ], - ids=["models.list", "chat.completions.create"], - ) - async def test_async_client_proper_error_without_initialization(self, api_call): - """Test that async client raises ValueError with helpful message when not initialized.""" + async def test_async_client_auto_initialization(self): + """Test that async client can be initialized and works properly.""" client = AsyncLlamaStackAsLibraryClient("nvidia") - with pytest.raises(ValueError) as exc_info: - await api_call(client) + # Initialize the client + result = await client.initialize() + assert result is True + assert client._is_initialized is True + assert client.route_impls is not None - error_msg = str(exc_info.value) - assert "Client not initialized" in error_msg - assert "Please call initialize() first" in error_msg + def test_initialize_method_backward_compatibility(self): + """Test that initialize() method still works for backward compatibility.""" + client = LlamaStackAsLibraryClient("nvidia") - async def test_async_client_streaming_error_without_initialization(self): - """Test that async client streaming raises ValueError with helpful message when not initialized.""" + # initialize() should return None (historical behavior) and not cause errors + result = client.initialize() + assert result is None + + # Multiple calls should be safe + result2 = client.initialize() + assert result2 is None + + async def test_async_initialize_method_idempotent(self): + """Test that async initialize() method can be called multiple times safely.""" client = AsyncLlamaStackAsLibraryClient("nvidia") - with pytest.raises(ValueError) as exc_info: - stream = await client.chat.completions.create( - model="test", messages=[{"role": "user", "content": "test"}], stream=True - ) - await anext(stream) + # First initialization + result1 = await client.initialize() + assert result1 is True + assert 
 
-        error_msg = str(exc_info.value)
-        assert "Client not initialized" in error_msg
-        assert "Please call initialize() first" in error_msg
+        # Second initialization should be safe and return True
+        result2 = await client.initialize()
+        assert result2 is True
+        assert client._is_initialized is True
 
-    def test_route_impls_initialized_to_none(self):
-        """Test that route_impls is initialized to None to prevent AttributeError."""
-        # Test sync client
+    def test_route_impls_automatically_set(self):
+        """Test that route_impls is automatically set during construction."""
+        # Test sync client - should be auto-initialized
         sync_client = LlamaStackAsLibraryClient("nvidia")
-        assert sync_client.async_client.route_impls is None
+        assert sync_client.async_client.route_impls is not None
 
-        # Test async client directly
-        async_client = AsyncLlamaStackAsLibraryClient("nvidia")
-        assert async_client.route_impls is None
+        # Test that the async client is marked as initialized
+        assert sync_client.async_client._is_initialized is True
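
For reference, the workflow this patch enables looks like the sketch below. It is illustrative and not part of the diff: the "nvidia" distro name mirrors the unit tests above, and the provider_data payload mirrors the docs example.

```python
import os

from llama_stack.core.library_client import LlamaStackAsLibraryClient

# Construction now performs initialization internally (notebook event-loop
# setup and root-logger cleanup included, unless skip_logger_removal=True),
# so no explicit initialize() call is needed.
client = LlamaStackAsLibraryClient(
    "nvidia",  # a distro name, or a path to a run-config .yaml
    provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
)

# The client is usable immediately after construction.
models = client.models.list()

# Older call sites may still invoke the deprecated no-op; it returns None
# and is safe to call any number of times.
client.initialize()
```

Note that AsyncLlamaStackAsLibraryClient.initialize() still does real work when the async client is used directly: per the tests above it returns True and sets the _is_initialized flag, which keeps repeated calls idempotent.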