mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
fix(library_client): improve initialization error handling and prevent AttributeError (#2944)
# What does this PR do? - Initialize route_impls to None in constructor to prevent AttributeError - Consolidate initialization checks to single point in request() method - Improve error message to be more helpful ("Please call initialize() first") - Add comprehensive test suite to prevent regressions The library client now has better error handling when users forget to call initialize(), showing a clear ValueError instead of confusing AttributeError. All initialization validation is now centralized in the request() method, with internal methods (_call_non_streaming, _call_streaming, _convert_body) relying on this single check for cleaner, more maintainable code. closes #2943 ## Test Plan `./scripts/unit-tests.sh`
This commit is contained in:
parent
9b69b6ac05
commit
b69bafba30
2 changed files with 97 additions and 12 deletions
|
@ -39,7 +39,7 @@ from llama_stack.distribution.request_headers import (
|
|||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.distribution.server.routes import RouteImpls, find_matching_route, initialize_route_impls
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
get_stack_run_config_from_template,
|
||||
|
@ -236,6 +236,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
self.config = config
|
||||
self.custom_provider_registry = custom_provider_registry
|
||||
self.provider_data = provider_data
|
||||
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
try:
|
||||
|
@ -297,8 +298,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
stream=False,
|
||||
stream_cls=None,
|
||||
):
|
||||
if not self.route_impls:
|
||||
raise ValueError("Client not initialized")
|
||||
if self.route_impls is None:
|
||||
raise ValueError("Client not initialized. Please call initialize() first.")
|
||||
|
||||
# Create headers with provider data if available
|
||||
headers = options.headers or {}
|
||||
|
@ -353,9 +354,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
cast_to: Any,
|
||||
options: Any,
|
||||
):
|
||||
if self.route_impls is None:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
@ -412,9 +411,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
options: Any,
|
||||
stream_cls: Any,
|
||||
):
|
||||
if self.route_impls is None:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
@ -474,9 +471,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
if not body:
|
||||
return {}
|
||||
|
||||
if self.route_impls is None:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||
exclude_params = exclude_params or set()
|
||||
|
||||
func, _, _, _ = find_matching_route(method, path, self.route_impls)
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Unit tests for LlamaStackAsLibraryClient initialization error handling.
|
||||
|
||||
These tests ensure that users get proper error messages when they forget to call
|
||||
initialize() on the library client, preventing AttributeError regressions.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.distribution.library_client import (
|
||||
AsyncLlamaStackAsLibraryClient,
|
||||
LlamaStackAsLibraryClient,
|
||||
)
|
||||
|
||||
|
||||
class TestLlamaStackAsLibraryClientInitialization:
|
||||
"""Test proper error handling for uninitialized library clients."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_call",
|
||||
[
|
||||
lambda client: client.models.list(),
|
||||
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
|
||||
lambda client: next(
|
||||
client.chat.completions.create(
|
||||
model="test", messages=[{"role": "user", "content": "test"}], stream=True
|
||||
)
|
||||
),
|
||||
],
|
||||
ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
|
||||
)
|
||||
def test_sync_client_proper_error_without_initialization(self, api_call):
|
||||
"""Test that sync client raises ValueError with helpful message when not initialized."""
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
api_call(client)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Client not initialized" in error_msg
|
||||
assert "Please call initialize() first" in error_msg
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_call",
|
||||
[
|
||||
lambda client: client.models.list(),
|
||||
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
|
||||
],
|
||||
ids=["models.list", "chat.completions.create"],
|
||||
)
|
||||
async def test_async_client_proper_error_without_initialization(self, api_call):
|
||||
"""Test that async client raises ValueError with helpful message when not initialized."""
|
||||
client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await api_call(client)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Client not initialized" in error_msg
|
||||
assert "Please call initialize() first" in error_msg
|
||||
|
||||
async def test_async_client_streaming_error_without_initialization(self):
|
||||
"""Test that async client streaming raises ValueError with helpful message when not initialized."""
|
||||
client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
stream = await client.chat.completions.create(
|
||||
model="test", messages=[{"role": "user", "content": "test"}], stream=True
|
||||
)
|
||||
await anext(stream)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Client not initialized" in error_msg
|
||||
assert "Please call initialize() first" in error_msg
|
||||
|
||||
def test_route_impls_initialized_to_none(self):
|
||||
"""Test that route_impls is initialized to None to prevent AttributeError."""
|
||||
# Test sync client
|
||||
sync_client = LlamaStackAsLibraryClient("nvidia")
|
||||
assert sync_client.async_client.route_impls is None
|
||||
|
||||
# Test async client directly
|
||||
async_client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||
assert async_client.route_impls is None
|
Loading…
Add table
Add a link
Reference in a new issue