mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
fix(library_client): improve initialization error handling and prevent AttributeError (#2944)
# What does this PR do? - Initialize route_impls to None in constructor to prevent AttributeError - Consolidate initialization checks to single point in request() method - Improve error message to be more helpful ("Please call initialize() first") - Add comprehensive test suite to prevent regressions The library client now has better error handling when users forget to call initialize(), showing a clear ValueError instead of confusing AttributeError. All initialization validation is now centralized in the request() method, with internal methods (_call_non_streaming, _call_streaming, _convert_body) relying on this single check for cleaner, more maintainable code. closes #2943 ## Test Plan `./scripts/unit-tests.sh`
This commit is contained in:
parent
9b69b6ac05
commit
b69bafba30
2 changed files with 97 additions and 12 deletions
|
@ -39,7 +39,7 @@ from llama_stack.distribution.request_headers import (
|
||||||
request_provider_data_context,
|
request_provider_data_context,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.resolver import ProviderRegistry
|
from llama_stack.distribution.resolver import ProviderRegistry
|
||||||
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
|
from llama_stack.distribution.server.routes import RouteImpls, find_matching_route, initialize_route_impls
|
||||||
from llama_stack.distribution.stack import (
|
from llama_stack.distribution.stack import (
|
||||||
construct_stack,
|
construct_stack,
|
||||||
get_stack_run_config_from_template,
|
get_stack_run_config_from_template,
|
||||||
|
@ -236,6 +236,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.custom_provider_registry = custom_provider_registry
|
self.custom_provider_registry = custom_provider_registry
|
||||||
self.provider_data = provider_data
|
self.provider_data = provider_data
|
||||||
|
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
|
||||||
|
|
||||||
async def initialize(self) -> bool:
|
async def initialize(self) -> bool:
|
||||||
try:
|
try:
|
||||||
|
@ -297,8 +298,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
stream=False,
|
stream=False,
|
||||||
stream_cls=None,
|
stream_cls=None,
|
||||||
):
|
):
|
||||||
if not self.route_impls:
|
if self.route_impls is None:
|
||||||
raise ValueError("Client not initialized")
|
raise ValueError("Client not initialized. Please call initialize() first.")
|
||||||
|
|
||||||
# Create headers with provider data if available
|
# Create headers with provider data if available
|
||||||
headers = options.headers or {}
|
headers = options.headers or {}
|
||||||
|
@ -353,9 +354,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
cast_to: Any,
|
cast_to: Any,
|
||||||
options: Any,
|
options: Any,
|
||||||
):
|
):
|
||||||
if self.route_impls is None:
|
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||||
raise ValueError("Client not initialized")
|
|
||||||
|
|
||||||
path = options.url
|
path = options.url
|
||||||
body = options.params or {}
|
body = options.params or {}
|
||||||
body |= options.json_data or {}
|
body |= options.json_data or {}
|
||||||
|
@ -412,9 +411,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
options: Any,
|
options: Any,
|
||||||
stream_cls: Any,
|
stream_cls: Any,
|
||||||
):
|
):
|
||||||
if self.route_impls is None:
|
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||||
raise ValueError("Client not initialized")
|
|
||||||
|
|
||||||
path = options.url
|
path = options.url
|
||||||
body = options.params or {}
|
body = options.params or {}
|
||||||
body |= options.json_data or {}
|
body |= options.json_data or {}
|
||||||
|
@ -474,9 +471,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
if not body:
|
if not body:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
if self.route_impls is None:
|
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||||
raise ValueError("Client not initialized")
|
|
||||||
|
|
||||||
exclude_params = exclude_params or set()
|
exclude_params = exclude_params or set()
|
||||||
|
|
||||||
func, _, _, _ = find_matching_route(method, path, self.route_impls)
|
func, _, _, _ = find_matching_route(method, path, self.route_impls)
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Unit tests for LlamaStackAsLibraryClient initialization error handling.
|
||||||
|
|
||||||
|
These tests ensure that users get proper error messages when they forget to call
|
||||||
|
initialize() on the library client, preventing AttributeError regressions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.distribution.library_client import (
|
||||||
|
AsyncLlamaStackAsLibraryClient,
|
||||||
|
LlamaStackAsLibraryClient,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLlamaStackAsLibraryClientInitialization:
|
||||||
|
"""Test proper error handling for uninitialized library clients."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"api_call",
|
||||||
|
[
|
||||||
|
lambda client: client.models.list(),
|
||||||
|
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
|
||||||
|
lambda client: next(
|
||||||
|
client.chat.completions.create(
|
||||||
|
model="test", messages=[{"role": "user", "content": "test"}], stream=True
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
|
||||||
|
)
|
||||||
|
def test_sync_client_proper_error_without_initialization(self, api_call):
|
||||||
|
"""Test that sync client raises ValueError with helpful message when not initialized."""
|
||||||
|
client = LlamaStackAsLibraryClient("nvidia")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
api_call(client)
|
||||||
|
|
||||||
|
error_msg = str(exc_info.value)
|
||||||
|
assert "Client not initialized" in error_msg
|
||||||
|
assert "Please call initialize() first" in error_msg
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"api_call",
|
||||||
|
[
|
||||||
|
lambda client: client.models.list(),
|
||||||
|
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
|
||||||
|
],
|
||||||
|
ids=["models.list", "chat.completions.create"],
|
||||||
|
)
|
||||||
|
async def test_async_client_proper_error_without_initialization(self, api_call):
|
||||||
|
"""Test that async client raises ValueError with helpful message when not initialized."""
|
||||||
|
client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
await api_call(client)
|
||||||
|
|
||||||
|
error_msg = str(exc_info.value)
|
||||||
|
assert "Client not initialized" in error_msg
|
||||||
|
assert "Please call initialize() first" in error_msg
|
||||||
|
|
||||||
|
async def test_async_client_streaming_error_without_initialization(self):
|
||||||
|
"""Test that async client streaming raises ValueError with helpful message when not initialized."""
|
||||||
|
client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
stream = await client.chat.completions.create(
|
||||||
|
model="test", messages=[{"role": "user", "content": "test"}], stream=True
|
||||||
|
)
|
||||||
|
await anext(stream)
|
||||||
|
|
||||||
|
error_msg = str(exc_info.value)
|
||||||
|
assert "Client not initialized" in error_msg
|
||||||
|
assert "Please call initialize() first" in error_msg
|
||||||
|
|
||||||
|
def test_route_impls_initialized_to_none(self):
|
||||||
|
"""Test that route_impls is initialized to None to prevent AttributeError."""
|
||||||
|
# Test sync client
|
||||||
|
sync_client = LlamaStackAsLibraryClient("nvidia")
|
||||||
|
assert sync_client.async_client.route_impls is None
|
||||||
|
|
||||||
|
# Test async client directly
|
||||||
|
async_client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||||
|
assert async_client.route_impls is None
|
Loading…
Add table
Add a link
Reference in a new issue