mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-22 17:53:55 +00:00
feat: Remove initialize() Method from LlamaStackAsLibrary (#2979)
# What does this PR do? <!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. --> This PR removes `init()` from `LlamaStackAsLibrary` Currently client.initialize() had to be invoked by user. To improve dev experience and to avoid runtime errors, this PR init LlamaStackAsLibrary implicitly upon using the client. It prevents also multiple init of the same client, while maintaining backward ccompatibility. This PR does the following - Automatic Initialization: Constructor calls initialize_impl() automatically. - Client is fully initialized after __init__ completes. - Prevents consecutive initialization after the client has been successfully initialized. - initialize() method still exists but is now a no-op. <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> fixes https://github.com/meta-llama/llama-stack/issues/2946 --------- Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
ac25e35124
commit
1790fc0f25
5 changed files with 128 additions and 88 deletions
|
@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient(
|
||||||
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
|
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
|
||||||
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
|
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
|
||||||
)
|
)
|
||||||
client.initialize()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
This will parse your config and set up any inline implementations and remote clients needed for your implementation.
|
This will parse your config and set up any inline implementations and remote clients needed for your implementation.
|
||||||
|
@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/
|
||||||
|
|
||||||
```python
|
```python
|
||||||
client = LlamaStackAsLibraryClient(config_path)
|
client = LlamaStackAsLibraryClient(config_path)
|
||||||
client.initialize()
|
|
||||||
```
|
```
|
||||||
|
|
|
@ -146,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||||
config_path_or_distro_name, custom_provider_registry, provider_data
|
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
|
||||||
)
|
)
|
||||||
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
||||||
self.skip_logger_removal = skip_logger_removal
|
|
||||||
self.provider_data = provider_data
|
self.provider_data = provider_data
|
||||||
|
|
||||||
self.loop = asyncio.new_event_loop()
|
self.loop = asyncio.new_event_loop()
|
||||||
|
|
||||||
def initialize(self):
|
|
||||||
if in_notebook():
|
|
||||||
import nest_asyncio
|
|
||||||
|
|
||||||
nest_asyncio.apply()
|
|
||||||
if not self.skip_logger_removal:
|
|
||||||
self._remove_root_logger_handlers()
|
|
||||||
|
|
||||||
# use a new event loop to avoid interfering with the main event loop
|
# use a new event loop to avoid interfering with the main event loop
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
try:
|
try:
|
||||||
return loop.run_until_complete(self.async_client.initialize())
|
loop.run_until_complete(self.async_client.initialize())
|
||||||
finally:
|
finally:
|
||||||
asyncio.set_event_loop(None)
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
def _remove_root_logger_handlers(self):
|
def initialize(self):
|
||||||
"""
|
"""
|
||||||
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
|
Deprecated method for backward compatibility.
|
||||||
"""
|
"""
|
||||||
root_logger = logging.getLogger()
|
pass
|
||||||
|
|
||||||
for handler in root_logger.handlers[:]:
|
|
||||||
root_logger.removeHandler(handler)
|
|
||||||
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
|
||||||
|
|
||||||
def request(self, *args, **kwargs):
|
def request(self, *args, **kwargs):
|
||||||
loop = self.loop
|
loop = self.loop
|
||||||
|
@ -216,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
config_path_or_distro_name: str,
|
config_path_or_distro_name: str,
|
||||||
custom_provider_registry: ProviderRegistry | None = None,
|
custom_provider_registry: ProviderRegistry | None = None,
|
||||||
provider_data: dict[str, Any] | None = None,
|
provider_data: dict[str, Any] | None = None,
|
||||||
|
skip_logger_removal: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# when using the library client, we should not log to console since many
|
# when using the library client, we should not log to console since many
|
||||||
|
@ -223,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
|
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
|
||||||
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
|
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
|
||||||
|
|
||||||
|
if in_notebook():
|
||||||
|
import nest_asyncio
|
||||||
|
|
||||||
|
nest_asyncio.apply()
|
||||||
|
if not skip_logger_removal:
|
||||||
|
self._remove_root_logger_handlers()
|
||||||
|
|
||||||
if config_path_or_distro_name.endswith(".yaml"):
|
if config_path_or_distro_name.endswith(".yaml"):
|
||||||
config_path = Path(config_path_or_distro_name)
|
config_path = Path(config_path_or_distro_name)
|
||||||
if not config_path.exists():
|
if not config_path.exists():
|
||||||
|
@ -239,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
self.provider_data = provider_data
|
self.provider_data = provider_data
|
||||||
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
|
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
|
||||||
|
|
||||||
|
def _remove_root_logger_handlers(self):
|
||||||
|
"""
|
||||||
|
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
|
||||||
|
"""
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
|
||||||
|
for handler in root_logger.handlers[:]:
|
||||||
|
root_logger.removeHandler(handler)
|
||||||
|
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
||||||
|
|
||||||
async def initialize(self) -> bool:
|
async def initialize(self) -> bool:
|
||||||
|
"""
|
||||||
|
Initialize the async client.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if initialization was successful
|
||||||
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.route_impls = None
|
self.route_impls = None
|
||||||
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
||||||
|
|
|
@ -256,9 +256,6 @@ def instantiate_llama_stack_client(session):
|
||||||
provider_data=get_provider_data(),
|
provider_data=get_provider_data(),
|
||||||
skip_logger_removal=True,
|
skip_logger_removal=True,
|
||||||
)
|
)
|
||||||
if not client.initialize():
|
|
||||||
raise RuntimeError("Initialization failed")
|
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -113,8 +113,6 @@ def openai_client(base_url, api_key, provider):
|
||||||
raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'")
|
raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'")
|
||||||
config = parts[1]
|
config = parts[1]
|
||||||
client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
|
client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
|
||||||
if not client.initialize():
|
|
||||||
raise RuntimeError("Initialization failed")
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
return OpenAI(
|
return OpenAI(
|
||||||
|
|
|
@ -5,86 +5,121 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Unit tests for LlamaStackAsLibraryClient initialization error handling.
|
Unit tests for LlamaStackAsLibraryClient automatic initialization.
|
||||||
|
|
||||||
These tests ensure that users get proper error messages when they forget to call
|
These tests ensure that the library client is automatically initialized
|
||||||
initialize() on the library client, preventing AttributeError regressions.
|
and ready to use immediately after construction.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from llama_stack.core.library_client import (
|
from llama_stack.core.library_client import (
|
||||||
AsyncLlamaStackAsLibraryClient,
|
AsyncLlamaStackAsLibraryClient,
|
||||||
LlamaStackAsLibraryClient,
|
LlamaStackAsLibraryClient,
|
||||||
)
|
)
|
||||||
|
from llama_stack.core.server.routes import RouteImpls
|
||||||
|
|
||||||
|
|
||||||
class TestLlamaStackAsLibraryClientInitialization:
|
class TestLlamaStackAsLibraryClientAutoInitialization:
|
||||||
"""Test proper error handling for uninitialized library clients."""
|
"""Test automatic initialization of library clients."""
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
def test_sync_client_auto_initialization(self, monkeypatch):
|
||||||
"api_call",
|
"""Test that sync client is automatically initialized after construction."""
|
||||||
[
|
# Mock the stack construction to avoid dependency issues
|
||||||
lambda client: client.models.list(),
|
mock_impls = {}
|
||||||
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
|
mock_route_impls = RouteImpls({})
|
||||||
lambda client: next(
|
|
||||||
client.chat.completions.create(
|
|
||||||
model="test", messages=[{"role": "user", "content": "test"}], stream=True
|
|
||||||
)
|
|
||||||
),
|
|
||||||
],
|
|
||||||
ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
|
|
||||||
)
|
|
||||||
def test_sync_client_proper_error_without_initialization(self, api_call):
|
|
||||||
"""Test that sync client raises ValueError with helpful message when not initialized."""
|
|
||||||
client = LlamaStackAsLibraryClient("nvidia")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
async def mock_construct_stack(config, custom_provider_registry):
|
||||||
api_call(client)
|
return mock_impls
|
||||||
|
|
||||||
error_msg = str(exc_info.value)
|
def mock_initialize_route_impls(impls):
|
||||||
assert "Client not initialized" in error_msg
|
return mock_route_impls
|
||||||
assert "Please call initialize() first" in error_msg
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||||
"api_call",
|
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||||
[
|
|
||||||
lambda client: client.models.list(),
|
|
||||||
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
|
|
||||||
],
|
|
||||||
ids=["models.list", "chat.completions.create"],
|
|
||||||
)
|
|
||||||
async def test_async_client_proper_error_without_initialization(self, api_call):
|
|
||||||
"""Test that async client raises ValueError with helpful message when not initialized."""
|
|
||||||
client = AsyncLlamaStackAsLibraryClient("nvidia")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
client = LlamaStackAsLibraryClient("ci-tests")
|
||||||
await api_call(client)
|
|
||||||
|
|
||||||
error_msg = str(exc_info.value)
|
assert client.async_client.route_impls is not None
|
||||||
assert "Client not initialized" in error_msg
|
|
||||||
assert "Please call initialize() first" in error_msg
|
|
||||||
|
|
||||||
async def test_async_client_streaming_error_without_initialization(self):
|
async def test_async_client_auto_initialization(self, monkeypatch):
|
||||||
"""Test that async client streaming raises ValueError with helpful message when not initialized."""
|
"""Test that async client can be initialized and works properly."""
|
||||||
client = AsyncLlamaStackAsLibraryClient("nvidia")
|
# Mock the stack construction to avoid dependency issues
|
||||||
|
mock_impls = {}
|
||||||
|
mock_route_impls = RouteImpls({})
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
async def mock_construct_stack(config, custom_provider_registry):
|
||||||
stream = await client.chat.completions.create(
|
return mock_impls
|
||||||
model="test", messages=[{"role": "user", "content": "test"}], stream=True
|
|
||||||
)
|
|
||||||
await anext(stream)
|
|
||||||
|
|
||||||
error_msg = str(exc_info.value)
|
def mock_initialize_route_impls(impls):
|
||||||
assert "Client not initialized" in error_msg
|
return mock_route_impls
|
||||||
assert "Please call initialize() first" in error_msg
|
|
||||||
|
|
||||||
def test_route_impls_initialized_to_none(self):
|
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||||
"""Test that route_impls is initialized to None to prevent AttributeError."""
|
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||||
# Test sync client
|
|
||||||
sync_client = LlamaStackAsLibraryClient("nvidia")
|
|
||||||
assert sync_client.async_client.route_impls is None
|
|
||||||
|
|
||||||
# Test async client directly
|
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||||
async_client = AsyncLlamaStackAsLibraryClient("nvidia")
|
|
||||||
assert async_client.route_impls is None
|
# Initialize the client
|
||||||
|
result = await client.initialize()
|
||||||
|
assert result is True
|
||||||
|
assert client.route_impls is not None
|
||||||
|
|
||||||
|
def test_initialize_method_backward_compatibility(self, monkeypatch):
|
||||||
|
"""Test that initialize() method still works for backward compatibility."""
|
||||||
|
# Mock the stack construction to avoid dependency issues
|
||||||
|
mock_impls = {}
|
||||||
|
mock_route_impls = RouteImpls({})
|
||||||
|
|
||||||
|
async def mock_construct_stack(config, custom_provider_registry):
|
||||||
|
return mock_impls
|
||||||
|
|
||||||
|
def mock_initialize_route_impls(impls):
|
||||||
|
return mock_route_impls
|
||||||
|
|
||||||
|
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||||
|
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||||
|
|
||||||
|
client = LlamaStackAsLibraryClient("ci-tests")
|
||||||
|
|
||||||
|
result = client.initialize()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
result2 = client.initialize()
|
||||||
|
assert result2 is None
|
||||||
|
|
||||||
|
async def test_async_initialize_method_idempotent(self, monkeypatch):
|
||||||
|
"""Test that async initialize() method can be called multiple times safely."""
|
||||||
|
mock_impls = {}
|
||||||
|
mock_route_impls = RouteImpls({})
|
||||||
|
|
||||||
|
async def mock_construct_stack(config, custom_provider_registry):
|
||||||
|
return mock_impls
|
||||||
|
|
||||||
|
def mock_initialize_route_impls(impls):
|
||||||
|
return mock_route_impls
|
||||||
|
|
||||||
|
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||||
|
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||||
|
|
||||||
|
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||||
|
|
||||||
|
result1 = await client.initialize()
|
||||||
|
assert result1 is True
|
||||||
|
|
||||||
|
result2 = await client.initialize()
|
||||||
|
assert result2 is True
|
||||||
|
|
||||||
|
def test_route_impls_automatically_set(self, monkeypatch):
|
||||||
|
"""Test that route_impls is automatically set during construction."""
|
||||||
|
mock_impls = {}
|
||||||
|
mock_route_impls = RouteImpls({})
|
||||||
|
|
||||||
|
async def mock_construct_stack(config, custom_provider_registry):
|
||||||
|
return mock_impls
|
||||||
|
|
||||||
|
def mock_initialize_route_impls(impls):
|
||||||
|
return mock_route_impls
|
||||||
|
|
||||||
|
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||||
|
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||||
|
|
||||||
|
sync_client = LlamaStackAsLibraryClient("ci-tests")
|
||||||
|
assert sync_client.async_client.route_impls is not None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue