refactor(client): remove initialize() Method from LlamaStackAsLibrary

Currently client.initialize() had to be invoked by user.
To improve dev experience and to avoid runtime errors, this PR init LlamaStackAsLibrary implicitly upon using the client.
It prevents also multiple init of the same client, while maintaining backward ccompatibility.

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-07-31 14:18:45 +02:00
parent 46ff302d87
commit 3778a4c3e6
5 changed files with 76 additions and 87 deletions

View file

@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient(
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here. # provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]}, provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
) )
client.initialize()
``` ```
This will parse your config and set up any inline implementations and remote clients needed for your implementation. This will parse your config and set up any inline implementations and remote clients needed for your implementation.
@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/
```python ```python
client = LlamaStackAsLibraryClient(config_path) client = LlamaStackAsLibraryClient(config_path)
client.initialize()
``` ```

View file

@ -145,7 +145,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
): ):
super().__init__() super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient( self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_distro_name, custom_provider_registry, provider_data config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
) )
self.pool_executor = ThreadPoolExecutor(max_workers=4) self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal self.skip_logger_removal = skip_logger_removal
@ -153,31 +153,19 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
def initialize(self):
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not self.skip_logger_removal:
self._remove_root_logger_handlers()
# use a new event loop to avoid interfering with the main event loop # use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
try: try:
return loop.run_until_complete(self.async_client.initialize()) loop.run_until_complete(self.async_client.initialize())
finally: finally:
asyncio.set_event_loop(None) asyncio.set_event_loop(None)
def _remove_root_logger_handlers(self): def initialize(self):
""" """
Remove all handlers from the root logger. Needed to avoid polluting the console with logs. Deprecated method for backward compatibility.
""" """
root_logger = logging.getLogger() pass
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
def request(self, *args, **kwargs): def request(self, *args, **kwargs):
loop = self.loop loop = self.loop
@ -215,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
config_path_or_distro_name: str, config_path_or_distro_name: str,
custom_provider_registry: ProviderRegistry | None = None, custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None, provider_data: dict[str, Any] | None = None,
skip_logger_removal: bool = False,
): ):
super().__init__() super().__init__()
# when using the library client, we should not log to console since many # when using the library client, we should not log to console since many
@ -222,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",") current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console") os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not skip_logger_removal:
self._remove_root_logger_handlers()
if config_path_or_distro_name.endswith(".yaml"): if config_path_or_distro_name.endswith(".yaml"):
config_path = Path(config_path_or_distro_name) config_path = Path(config_path_or_distro_name)
if not config_path.exists(): if not config_path.exists():
@ -238,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.provider_data = provider_data self.provider_data = provider_data
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
async def initialize(self) -> bool: async def initialize(self) -> bool:
"""
Initialize the async client. Can be called multiple times safely.
Returns:
bool: True if initialization was successful
"""
try: try:
self.route_impls = None self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry) self.impls = await construct_stack(self.config, self.custom_provider_registry)
@ -298,9 +311,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
stream=False, stream=False,
stream_cls=None, stream_cls=None,
): ):
if self.route_impls is None:
raise ValueError("Client not initialized. Please call initialize() first.")
# Create headers with provider data if available # Create headers with provider data if available
headers = options.headers or {} headers = options.headers or {}
if self.provider_data: if self.provider_data:

View file

@ -256,9 +256,7 @@ def instantiate_llama_stack_client(session):
provider_data=get_provider_data(), provider_data=get_provider_data(),
skip_logger_removal=True, skip_logger_removal=True,
) )
if not client.initialize(): # Client is automatically initialized during construction
raise RuntimeError("Initialization failed")
return client return client

View file

@ -127,8 +127,7 @@ def openai_client(base_url, api_key, provider):
raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'") raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'")
config = parts[1] config = parts[1]
client = LlamaStackAsLibraryClient(config, skip_logger_removal=True) client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
if not client.initialize(): # Client is automatically initialized during construction
raise RuntimeError("Initialization failed")
return client return client
return OpenAI( return OpenAI(

View file

@ -5,86 +5,70 @@
# the root directory of this source tree. # the root directory of this source tree.
""" """
Unit tests for LlamaStackAsLibraryClient initialization error handling. Unit tests for LlamaStackAsLibraryClient automatic initialization.
These tests ensure that users get proper error messages when they forget to call These tests ensure that the library client is automatically initialized
initialize() on the library client, preventing AttributeError regressions. and ready to use immediately after construction.
""" """
import pytest
from llama_stack.core.library_client import ( from llama_stack.core.library_client import (
AsyncLlamaStackAsLibraryClient, AsyncLlamaStackAsLibraryClient,
LlamaStackAsLibraryClient, LlamaStackAsLibraryClient,
) )
class TestLlamaStackAsLibraryClientInitialization: class TestLlamaStackAsLibraryClientAutoInitialization:
"""Test proper error handling for uninitialized library clients.""" """Test automatic initialization of library clients."""
@pytest.mark.parametrize( def test_sync_client_auto_initialization(self):
"api_call", """Test that sync client is automatically initialized after construction."""
[
lambda client: client.models.list(),
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
lambda client: next(
client.chat.completions.create(
model="test", messages=[{"role": "user", "content": "test"}], stream=True
)
),
],
ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
)
def test_sync_client_proper_error_without_initialization(self, api_call):
"""Test that sync client raises ValueError with helpful message when not initialized."""
client = LlamaStackAsLibraryClient("nvidia") client = LlamaStackAsLibraryClient("nvidia")
with pytest.raises(ValueError) as exc_info: # Client should be automatically initialized
api_call(client) assert client.async_client._is_initialized is True
assert client.async_client.route_impls is not None
error_msg = str(exc_info.value) async def test_async_client_auto_initialization(self):
assert "Client not initialized" in error_msg """Test that async client can be initialized and works properly."""
assert "Please call initialize() first" in error_msg
@pytest.mark.parametrize(
"api_call",
[
lambda client: client.models.list(),
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
],
ids=["models.list", "chat.completions.create"],
)
async def test_async_client_proper_error_without_initialization(self, api_call):
"""Test that async client raises ValueError with helpful message when not initialized."""
client = AsyncLlamaStackAsLibraryClient("nvidia") client = AsyncLlamaStackAsLibraryClient("nvidia")
with pytest.raises(ValueError) as exc_info: # Initialize the client
await api_call(client) result = await client.initialize()
assert result is True
assert client._is_initialized is True
assert client.route_impls is not None
error_msg = str(exc_info.value) def test_initialize_method_backward_compatibility(self):
assert "Client not initialized" in error_msg """Test that initialize() method still works for backward compatibility."""
assert "Please call initialize() first" in error_msg client = LlamaStackAsLibraryClient("nvidia")
async def test_async_client_streaming_error_without_initialization(self): # initialize() should return None (historical behavior) and not cause errors
"""Test that async client streaming raises ValueError with helpful message when not initialized.""" result = client.initialize()
assert result is None
# Multiple calls should be safe
result2 = client.initialize()
assert result2 is None
async def test_async_initialize_method_idempotent(self):
"""Test that async initialize() method can be called multiple times safely."""
client = AsyncLlamaStackAsLibraryClient("nvidia") client = AsyncLlamaStackAsLibraryClient("nvidia")
with pytest.raises(ValueError) as exc_info: # First initialization
stream = await client.chat.completions.create( result1 = await client.initialize()
model="test", messages=[{"role": "user", "content": "test"}], stream=True assert result1 is True
) assert client._is_initialized is True
await anext(stream)
error_msg = str(exc_info.value) # Second initialization should be safe and return True
assert "Client not initialized" in error_msg result2 = await client.initialize()
assert "Please call initialize() first" in error_msg assert result2 is True
assert client._is_initialized is True
def test_route_impls_initialized_to_none(self): def test_route_impls_automatically_set(self):
"""Test that route_impls is initialized to None to prevent AttributeError.""" """Test that route_impls is automatically set during construction."""
# Test sync client # Test sync client - should be auto-initialized
sync_client = LlamaStackAsLibraryClient("nvidia") sync_client = LlamaStackAsLibraryClient("nvidia")
assert sync_client.async_client.route_impls is None assert sync_client.async_client.route_impls is not None
# Test async client directly # Test that the async client is marked as initialized
async_client = AsyncLlamaStackAsLibraryClient("nvidia") assert sync_client.async_client._is_initialized is True
assert async_client.route_impls is None