mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 14:08:00 +00:00
Merge bbe2a89adc
into 81ecaf6221
This commit is contained in:
commit
e88f15b93e
5 changed files with 130 additions and 90 deletions
|
@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient(
|
||||||
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
|
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
|
||||||
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
|
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
|
||||||
)
|
)
|
||||||
client.initialize()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
This will parse your config and set up any inline implementations and remote clients needed for your implementation.
|
This will parse your config and set up any inline implementations and remote clients needed for your implementation.
|
||||||
|
@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/
|
||||||
|
|
||||||
```python
|
```python
|
||||||
client = LlamaStackAsLibraryClient(config_path)
|
client = LlamaStackAsLibraryClient(config_path)
|
||||||
client.initialize()
|
|
||||||
```
|
```
|
||||||
|
|
|
@ -145,7 +145,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||||
config_path_or_distro_name, custom_provider_registry, provider_data
|
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
|
||||||
)
|
)
|
||||||
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
||||||
self.skip_logger_removal = skip_logger_removal
|
self.skip_logger_removal = skip_logger_removal
|
||||||
|
@ -153,31 +153,19 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||||
|
|
||||||
self.loop = asyncio.new_event_loop()
|
self.loop = asyncio.new_event_loop()
|
||||||
|
|
||||||
def initialize(self):
|
|
||||||
if in_notebook():
|
|
||||||
import nest_asyncio
|
|
||||||
|
|
||||||
nest_asyncio.apply()
|
|
||||||
if not self.skip_logger_removal:
|
|
||||||
self._remove_root_logger_handlers()
|
|
||||||
|
|
||||||
# use a new event loop to avoid interfering with the main event loop
|
# use a new event loop to avoid interfering with the main event loop
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
try:
|
try:
|
||||||
return loop.run_until_complete(self.async_client.initialize())
|
loop.run_until_complete(self.async_client.initialize())
|
||||||
finally:
|
finally:
|
||||||
asyncio.set_event_loop(None)
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
def _remove_root_logger_handlers(self):
|
def initialize(self):
|
||||||
"""
|
"""
|
||||||
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
|
Deprecated method for backward compatibility.
|
||||||
"""
|
"""
|
||||||
root_logger = logging.getLogger()
|
pass
|
||||||
|
|
||||||
for handler in root_logger.handlers[:]:
|
|
||||||
root_logger.removeHandler(handler)
|
|
||||||
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
|
||||||
|
|
||||||
def request(self, *args, **kwargs):
|
def request(self, *args, **kwargs):
|
||||||
loop = self.loop
|
loop = self.loop
|
||||||
|
@ -215,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
config_path_or_distro_name: str,
|
config_path_or_distro_name: str,
|
||||||
custom_provider_registry: ProviderRegistry | None = None,
|
custom_provider_registry: ProviderRegistry | None = None,
|
||||||
provider_data: dict[str, Any] | None = None,
|
provider_data: dict[str, Any] | None = None,
|
||||||
|
skip_logger_removal: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# when using the library client, we should not log to console since many
|
# when using the library client, we should not log to console since many
|
||||||
|
@ -222,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
|
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
|
||||||
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
|
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
|
||||||
|
|
||||||
|
if in_notebook():
|
||||||
|
import nest_asyncio
|
||||||
|
|
||||||
|
nest_asyncio.apply()
|
||||||
|
if not skip_logger_removal:
|
||||||
|
self._remove_root_logger_handlers()
|
||||||
|
|
||||||
if config_path_or_distro_name.endswith(".yaml"):
|
if config_path_or_distro_name.endswith(".yaml"):
|
||||||
config_path = Path(config_path_or_distro_name)
|
config_path = Path(config_path_or_distro_name)
|
||||||
if not config_path.exists():
|
if not config_path.exists():
|
||||||
|
@ -238,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
self.provider_data = provider_data
|
self.provider_data = provider_data
|
||||||
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
|
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
|
||||||
|
|
||||||
|
def _remove_root_logger_handlers(self):
|
||||||
|
"""
|
||||||
|
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
|
||||||
|
"""
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
|
||||||
|
for handler in root_logger.handlers[:]:
|
||||||
|
root_logger.removeHandler(handler)
|
||||||
|
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
||||||
|
|
||||||
async def initialize(self) -> bool:
|
async def initialize(self) -> bool:
|
||||||
|
"""
|
||||||
|
Initialize the async client. Can be called multiple times safely.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if initialization was successful
|
||||||
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.route_impls = None
|
self.route_impls = None
|
||||||
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
||||||
|
@ -298,9 +311,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
stream=False,
|
stream=False,
|
||||||
stream_cls=None,
|
stream_cls=None,
|
||||||
):
|
):
|
||||||
if self.route_impls is None:
|
|
||||||
raise ValueError("Client not initialized. Please call initialize() first.")
|
|
||||||
|
|
||||||
# Create headers with provider data if available
|
# Create headers with provider data if available
|
||||||
headers = options.headers or {}
|
headers = options.headers or {}
|
||||||
if self.provider_data:
|
if self.provider_data:
|
||||||
|
|
|
@ -256,9 +256,7 @@ def instantiate_llama_stack_client(session):
|
||||||
provider_data=get_provider_data(),
|
provider_data=get_provider_data(),
|
||||||
skip_logger_removal=True,
|
skip_logger_removal=True,
|
||||||
)
|
)
|
||||||
if not client.initialize():
|
# Client is automatically initialized during construction
|
||||||
raise RuntimeError("Initialization failed")
|
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -113,8 +113,7 @@ def openai_client(base_url, api_key, provider):
|
||||||
raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'")
|
raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'")
|
||||||
config = parts[1]
|
config = parts[1]
|
||||||
client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
|
client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
|
||||||
if not client.initialize():
|
# Client is automatically initialized during construction
|
||||||
raise RuntimeError("Initialization failed")
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
return OpenAI(
|
return OpenAI(
|
||||||
|
|
|
@ -5,86 +5,121 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Unit tests for LlamaStackAsLibraryClient initialization error handling.
|
Unit tests for LlamaStackAsLibraryClient automatic initialization.
|
||||||
|
|
||||||
These tests ensure that users get proper error messages when they forget to call
|
These tests ensure that the library client is automatically initialized
|
||||||
initialize() on the library client, preventing AttributeError regressions.
|
and ready to use immediately after construction.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from llama_stack.core.library_client import (
|
from llama_stack.core.library_client import (
|
||||||
AsyncLlamaStackAsLibraryClient,
|
AsyncLlamaStackAsLibraryClient,
|
||||||
LlamaStackAsLibraryClient,
|
LlamaStackAsLibraryClient,
|
||||||
)
|
)
|
||||||
|
from llama_stack.core.server.routes import RouteImpls
|
||||||
|
|
||||||
|
|
||||||
class TestLlamaStackAsLibraryClientInitialization:
|
class TestLlamaStackAsLibraryClientAutoInitialization:
|
||||||
"""Test proper error handling for uninitialized library clients."""
|
"""Test automatic initialization of library clients."""
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
def test_sync_client_auto_initialization(self, monkeypatch):
|
||||||
"api_call",
|
"""Test that sync client is automatically initialized after construction."""
|
||||||
[
|
# Mock the stack construction to avoid dependency issues
|
||||||
lambda client: client.models.list(),
|
mock_impls = {}
|
||||||
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
|
mock_route_impls = RouteImpls({})
|
||||||
lambda client: next(
|
|
||||||
client.chat.completions.create(
|
|
||||||
model="test", messages=[{"role": "user", "content": "test"}], stream=True
|
|
||||||
)
|
|
||||||
),
|
|
||||||
],
|
|
||||||
ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
|
|
||||||
)
|
|
||||||
def test_sync_client_proper_error_without_initialization(self, api_call):
|
|
||||||
"""Test that sync client raises ValueError with helpful message when not initialized."""
|
|
||||||
client = LlamaStackAsLibraryClient("nvidia")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
async def mock_construct_stack(config, custom_provider_registry):
|
||||||
api_call(client)
|
return mock_impls
|
||||||
|
|
||||||
error_msg = str(exc_info.value)
|
def mock_initialize_route_impls(impls):
|
||||||
assert "Client not initialized" in error_msg
|
return mock_route_impls
|
||||||
assert "Please call initialize() first" in error_msg
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||||
"api_call",
|
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||||
[
|
|
||||||
lambda client: client.models.list(),
|
|
||||||
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
|
|
||||||
],
|
|
||||||
ids=["models.list", "chat.completions.create"],
|
|
||||||
)
|
|
||||||
async def test_async_client_proper_error_without_initialization(self, api_call):
|
|
||||||
"""Test that async client raises ValueError with helpful message when not initialized."""
|
|
||||||
client = AsyncLlamaStackAsLibraryClient("nvidia")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
client = LlamaStackAsLibraryClient("ci-tests")
|
||||||
await api_call(client)
|
|
||||||
|
|
||||||
error_msg = str(exc_info.value)
|
assert client.async_client.route_impls is not None
|
||||||
assert "Client not initialized" in error_msg
|
|
||||||
assert "Please call initialize() first" in error_msg
|
|
||||||
|
|
||||||
async def test_async_client_streaming_error_without_initialization(self):
|
async def test_async_client_auto_initialization(self, monkeypatch):
|
||||||
"""Test that async client streaming raises ValueError with helpful message when not initialized."""
|
"""Test that async client can be initialized and works properly."""
|
||||||
client = AsyncLlamaStackAsLibraryClient("nvidia")
|
# Mock the stack construction to avoid dependency issues
|
||||||
|
mock_impls = {}
|
||||||
|
mock_route_impls = RouteImpls({})
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
async def mock_construct_stack(config, custom_provider_registry):
|
||||||
stream = await client.chat.completions.create(
|
return mock_impls
|
||||||
model="test", messages=[{"role": "user", "content": "test"}], stream=True
|
|
||||||
)
|
|
||||||
await anext(stream)
|
|
||||||
|
|
||||||
error_msg = str(exc_info.value)
|
def mock_initialize_route_impls(impls):
|
||||||
assert "Client not initialized" in error_msg
|
return mock_route_impls
|
||||||
assert "Please call initialize() first" in error_msg
|
|
||||||
|
|
||||||
def test_route_impls_initialized_to_none(self):
|
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||||
"""Test that route_impls is initialized to None to prevent AttributeError."""
|
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||||
# Test sync client
|
|
||||||
sync_client = LlamaStackAsLibraryClient("nvidia")
|
|
||||||
assert sync_client.async_client.route_impls is None
|
|
||||||
|
|
||||||
# Test async client directly
|
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||||
async_client = AsyncLlamaStackAsLibraryClient("nvidia")
|
|
||||||
assert async_client.route_impls is None
|
# Initialize the client
|
||||||
|
result = await client.initialize()
|
||||||
|
assert result is True
|
||||||
|
assert client.route_impls is not None
|
||||||
|
|
||||||
|
def test_initialize_method_backward_compatibility(self, monkeypatch):
|
||||||
|
"""Test that initialize() method still works for backward compatibility."""
|
||||||
|
# Mock the stack construction to avoid dependency issues
|
||||||
|
mock_impls = {}
|
||||||
|
mock_route_impls = RouteImpls({})
|
||||||
|
|
||||||
|
async def mock_construct_stack(config, custom_provider_registry):
|
||||||
|
return mock_impls
|
||||||
|
|
||||||
|
def mock_initialize_route_impls(impls):
|
||||||
|
return mock_route_impls
|
||||||
|
|
||||||
|
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||||
|
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||||
|
|
||||||
|
client = LlamaStackAsLibraryClient("ci-tests")
|
||||||
|
|
||||||
|
result = client.initialize()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
result2 = client.initialize()
|
||||||
|
assert result2 is None
|
||||||
|
|
||||||
|
async def test_async_initialize_method_idempotent(self, monkeypatch):
|
||||||
|
"""Test that async initialize() method can be called multiple times safely."""
|
||||||
|
mock_impls = {}
|
||||||
|
mock_route_impls = RouteImpls({})
|
||||||
|
|
||||||
|
async def mock_construct_stack(config, custom_provider_registry):
|
||||||
|
return mock_impls
|
||||||
|
|
||||||
|
def mock_initialize_route_impls(impls):
|
||||||
|
return mock_route_impls
|
||||||
|
|
||||||
|
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||||
|
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||||
|
|
||||||
|
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||||
|
|
||||||
|
result1 = await client.initialize()
|
||||||
|
assert result1 is True
|
||||||
|
|
||||||
|
result2 = await client.initialize()
|
||||||
|
assert result2 is True
|
||||||
|
|
||||||
|
def test_route_impls_automatically_set(self, monkeypatch):
|
||||||
|
"""Test that route_impls is automatically set during construction."""
|
||||||
|
mock_impls = {}
|
||||||
|
mock_route_impls = RouteImpls({})
|
||||||
|
|
||||||
|
async def mock_construct_stack(config, custom_provider_registry):
|
||||||
|
return mock_impls
|
||||||
|
|
||||||
|
def mock_initialize_route_impls(impls):
|
||||||
|
return mock_route_impls
|
||||||
|
|
||||||
|
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||||
|
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||||
|
|
||||||
|
sync_client = LlamaStackAsLibraryClient("ci-tests")
|
||||||
|
assert sync_client.async_client.route_impls is not None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue