feat: add a configurable category-based logger (#1352)

A self-respecting server needs good observability which starts with
configurable logging. Llama Stack had little until now. This PR adds a
`logcat` facility towards that. Callsites look like:

```python
logcat.debug("inference", f"params to ollama: {params}")
```

- the first parameter is a category. there is a static list of
categories in `llama_stack/logcat.py`
- each category can be associated with a log-level which can be
configured via the `LLAMA_STACK_LOGGING` env var.
- a value `LLAMA_STACK_LOGGING=inference=debug;server=info"` does the
obvious thing. there is a special key called `all` which is an alias for
all categories

## Test Plan

Ran with `LLAMA_STACK_LOGGING="all=debug" llama stack run fireworks` and
saw the following:


![image](https://github.com/user-attachments/assets/d24b95ab-3941-426c-9ea0-a4c62542e6f0)

Hit it with a client-sdk test case and saw this:


![image](https://github.com/user-attachments/assets/3fee8c6c-986e-4125-a09c-f5dc019682e2)
This commit is contained in:
Ashwin Bharambe 2025-03-02 18:51:14 -08:00 committed by GitHub
parent a9a7b11326
commit 754feba61f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 409 additions and 47 deletions

View file

@ -26,9 +26,9 @@ from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack import logcat
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import set_request_provider_data
@ -55,7 +55,7 @@ from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
logger = logging.getLogger(__name__)
logcat.init()
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -142,23 +142,23 @@ def handle_signal(app, signum, _) -> None:
not block the current execution.
"""
signame = signal.Signals(signum).name
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
async def shutdown():
try:
# Gracefully shut down implementations
for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name)
logcat.info("server", f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
logcat.warning("server", f"No shutdown method for {impl_name}")
except asyncio.TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
logcat.exception("server", f"Shutdown timeout for {impl_name}")
except Exception as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
# Gather all running tasks
loop = asyncio.get_running_loop()
@ -172,7 +172,7 @@ def handle_signal(app, signum, _) -> None:
try:
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError:
logger.exception("Timeout while waiting for tasks to finish")
logcat.exception("server", "Timeout while waiting for tasks to finish")
except asyncio.CancelledError:
pass
finally:
@ -184,9 +184,9 @@ def handle_signal(app, signum, _) -> None:
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting up")
logcat.info("server", "Starting up")
yield
logger.info("Shutting down")
logcat.info("server", "Shutting down")
for impl in app.__llama_stack_impls__.values():
await impl.shutdown()
@ -209,10 +209,10 @@ async def sse_generator(event_gen):
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
logcat.info("server", "Generator cancelled")
await event_gen.aclose()
except Exception as e:
traceback.print_exception(e)
logcat.exception("server", "Error in sse_generator")
yield create_sse_event(
{
"error": {
@ -234,7 +234,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
value = func(**kwargs)
return await maybe_await(value)
except Exception as e:
traceback.print_exception(e)
logcat.exception("server", f"Error in {func.__name__}")
raise translate_exception(e) from e
sig = inspect.signature(func)
@ -313,6 +313,8 @@ class ClientVersionMiddleware:
def main():
logcat.init()
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
@ -352,10 +354,10 @@ def main():
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
logger.info(f"Setting CLI environment variable {key} => {value}")
logcat.info("server", f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logger.error(f"Error: {str(e)}")
logcat.error("server", f"Error: {str(e)}")
sys.exit(1)
if args.yaml_config:
@ -363,12 +365,12 @@ def main():
config_file = Path(args.yaml_config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
logger.info(f"Using config file: {config_file}")
logcat.info("server", f"Using config file: {config_file}")
elif args.template:
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
logger.info(f"Using template {args.template} config file: {config_file}")
logcat.info("server", f"Using template {args.template} config file: {config_file}")
else:
raise ValueError("Either --yaml-config or --template must be provided")
@ -376,9 +378,10 @@ def main():
config = replace_env_vars(yaml.safe_load(fp))
config = StackRunConfig(**config)
logger.info("Run configuration:")
logcat.info("server", "Run configuration:")
safe_config = redact_sensitive_fields(config.model_dump())
logger.info(yaml.dump(safe_config, indent=2))
for log_line in yaml.dump(safe_config, indent=2).split("\n"):
logcat.info("server", log_line)
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
@ -388,7 +391,7 @@ def main():
try:
impls = asyncio.run(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
logcat.error("server", f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls:
@ -433,11 +436,8 @@ def main():
)
)
logger.info(f"Serving API {api_str}")
for endpoint in endpoints:
cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
logcat.debug("server", f"Serving API {api_str}")
print("")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
@ -463,10 +463,10 @@ def main():
"ssl_keyfile": keyfile,
"ssl_certfile": certfile,
}
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
logger.info(f"Listening on {listen_host}:{port}")
logcat.info("server", f"Listening on {listen_host}:{port}")
uvicorn_config = {
"app": app,