mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-18 00:47:15 +00:00
refactor(client): remove initialize() Method from LlamaStackAsLibrary
Currently client.initialize() had to be invoked by user. To improve dev experience and to avoid runtime errors, this PR init LlamaStackAsLibrary implicitly upon using the client. It prevents also multiple init of the same client, while maintaining backward ccompatibility. Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
ac25e35124
commit
c54278f3d7
5 changed files with 76 additions and 87 deletions
|
|
@ -146,7 +146,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
):
|
||||
super().__init__()
|
||||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||
config_path_or_distro_name, custom_provider_registry, provider_data
|
||||
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
|
||||
)
|
||||
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
||||
self.skip_logger_removal = skip_logger_removal
|
||||
|
|
@ -154,31 +154,19 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
|
||||
self.loop = asyncio.new_event_loop()
|
||||
|
||||
def initialize(self):
|
||||
if in_notebook():
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
if not self.skip_logger_removal:
|
||||
self._remove_root_logger_handlers()
|
||||
|
||||
# use a new event loop to avoid interfering with the main event loop
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(self.async_client.initialize())
|
||||
loop.run_until_complete(self.async_client.initialize())
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
def _remove_root_logger_handlers(self):
|
||||
def initialize(self):
|
||||
"""
|
||||
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
|
||||
Deprecated method for backward compatibility.
|
||||
"""
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
||||
pass
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
loop = self.loop
|
||||
|
|
@ -216,6 +204,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
config_path_or_distro_name: str,
|
||||
custom_provider_registry: ProviderRegistry | None = None,
|
||||
provider_data: dict[str, Any] | None = None,
|
||||
skip_logger_removal: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# when using the library client, we should not log to console since many
|
||||
|
|
@ -223,6 +212,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
|
||||
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
|
||||
|
||||
if in_notebook():
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
if not skip_logger_removal:
|
||||
self._remove_root_logger_handlers()
|
||||
|
||||
if config_path_or_distro_name.endswith(".yaml"):
|
||||
config_path = Path(config_path_or_distro_name)
|
||||
if not config_path.exists():
|
||||
|
|
@ -239,7 +235,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
self.provider_data = provider_data
|
||||
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
|
||||
|
||||
def _remove_root_logger_handlers(self):
|
||||
"""
|
||||
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
|
||||
"""
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""
|
||||
Initialize the async client. Can be called multiple times safely.
|
||||
|
||||
Returns:
|
||||
bool: True if initialization was successful
|
||||
"""
|
||||
|
||||
try:
|
||||
self.route_impls = None
|
||||
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
||||
|
|
@ -299,9 +312,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
stream=False,
|
||||
stream_cls=None,
|
||||
):
|
||||
if self.route_impls is None:
|
||||
raise ValueError("Client not initialized. Please call initialize() first.")
|
||||
|
||||
# Create headers with provider data if available
|
||||
headers = options.headers or {}
|
||||
if self.provider_data:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue