llama_toolchain -> llama_stack

This commit is contained in:
Ashwin Bharambe 2024-09-16 17:21:08 -07:00
parent f372355409
commit 2cf731faea
175 changed files with 300 additions and 279 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .api import * # noqa: F401 F403

View file

@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@json_schema_type
class SpanStatus(Enum):
OK = "ok"
ERROR = "error"
@json_schema_type
class Span(BaseModel):
span_id: str
trace_id: str
parent_span_id: Optional[str] = None
name: str
start_time: datetime
end_time: Optional[datetime] = None
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
@json_schema_type
class Trace(BaseModel):
trace_id: str
root_span_id: str
start_time: datetime
end_time: Optional[datetime] = None
@json_schema_type
class EventType(Enum):
UNSTRUCTURED_LOG = "unstructured_log"
STRUCTURED_LOG = "structured_log"
METRIC = "metric"
@json_schema_type
class LogSeverity(Enum):
VERBOSE = "verbose"
DEBUG = "debug"
INFO = "info"
WARN = "warn"
ERROR = "error"
CRITICAL = "critical"
class EventCommon(BaseModel):
trace_id: str
span_id: str
timestamp: datetime
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
@json_schema_type
class UnstructuredLogEvent(EventCommon):
type: Literal[EventType.UNSTRUCTURED_LOG.value] = EventType.UNSTRUCTURED_LOG.value
message: str
severity: LogSeverity
@json_schema_type
class MetricEvent(EventCommon):
type: Literal[EventType.METRIC.value] = EventType.METRIC.value
metric: str # this would be an enum
value: Union[int, float]
unit: str
@json_schema_type
class StructuredLogType(Enum):
SPAN_START = "span_start"
SPAN_END = "span_end"
@json_schema_type
class SpanStartPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_START.value] = (
StructuredLogType.SPAN_START.value
)
name: str
parent_span_id: Optional[str] = None
@json_schema_type
class SpanEndPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_END.value] = StructuredLogType.SPAN_END.value
status: SpanStatus
StructuredLogPayload = Annotated[
Union[
SpanStartPayload,
SpanEndPayload,
],
Field(discriminator="type"),
]
@json_schema_type
class StructuredLogEvent(EventCommon):
type: Literal[EventType.STRUCTURED_LOG.value] = EventType.STRUCTURED_LOG.value
payload: StructuredLogPayload
Event = Annotated[
Union[
UnstructuredLogEvent,
MetricEvent,
StructuredLogEvent,
],
Field(discriminator="type"),
]
class Telemetry(Protocol):
@webmethod(route="/telemetry/log_event")
async def log_event(self, event: Event) -> None: ...
@webmethod(route="/telemetry/get_trace", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ...

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import ConsoleConfig
async def get_provider_impl(config: ConsoleConfig, _deps):
from .console import ConsoleTelemetryImpl
impl = ConsoleTelemetryImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class ConsoleConfig(BaseModel): ...

View file

@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_stack.telemetry.api import * # noqa: F403
from .config import ConsoleConfig
class ConsoleTelemetryImpl(Telemetry):
def __init__(self, config: ConsoleConfig) -> None:
self.config = config
self.spans = {}
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def log_event(self, event: Event):
if (
isinstance(event, StructuredLogEvent)
and event.payload.type == StructuredLogType.SPAN_START.value
):
self.spans[event.span_id] = event.payload
names = []
span_id = event.span_id
while True:
span_payload = self.spans.get(span_id)
if not span_payload:
break
names = [span_payload.name] + names
span_id = span_payload.parent_span_id
span_name = ".".join(names) if names else None
formatted = format_event(event, span_name)
if formatted:
print(formatted)
async def get_trace(self, trace_id: str) -> Trace:
raise NotImplementedError()
COLORS = {
"reset": "\033[0m",
"bold": "\033[1m",
"dim": "\033[2m",
"red": "\033[31m",
"green": "\033[32m",
"yellow": "\033[33m",
"blue": "\033[34m",
"magenta": "\033[35m",
"cyan": "\033[36m",
"white": "\033[37m",
}
SEVERITY_COLORS = {
LogSeverity.VERBOSE: COLORS["dim"] + COLORS["white"],
LogSeverity.DEBUG: COLORS["cyan"],
LogSeverity.INFO: COLORS["green"],
LogSeverity.WARN: COLORS["yellow"],
LogSeverity.ERROR: COLORS["red"],
LogSeverity.CRITICAL: COLORS["bold"] + COLORS["red"],
}
def format_event(event: Event, span_name: str) -> Optional[str]:
timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
span = ""
if span_name:
span = f"{COLORS['magenta']}[{span_name}]{COLORS['reset']} "
if isinstance(event, UnstructuredLogEvent):
severity_color = SEVERITY_COLORS.get(event.severity, COLORS["reset"])
return (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{severity_color}[{event.severity.name}]{COLORS['reset']} "
f"{span}"
f"{event.message}"
)
elif isinstance(event, StructuredLogEvent):
return None
return f"Unknown event type: {event}"

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.core.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.telemetry,
provider_id="console",
pip_packages=[],
module="llama_stack.telemetry.console",
config_class="llama_stack.telemetry.console.ConsoleConfig",
),
]

View file

@ -0,0 +1,236 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import logging
import queue
import threading
import uuid
from datetime import datetime
from functools import wraps
from typing import Any, Dict, List
from llama_stack.telemetry.api import * # noqa: F403
def generate_short_uuid(len: int = 12):
full_uuid = uuid.uuid4()
uuid_bytes = full_uuid.bytes
encoded = base64.urlsafe_b64encode(uuid_bytes)
return encoded.rstrip(b"=").decode("ascii")[:len]
CURRENT_TRACE_CONTEXT = None
BACKGROUND_LOGGER = None
class BackgroundLogger:
def __init__(self, api: Telemetry, capacity: int = 1000):
self.api = api
self.log_queue = queue.Queue(maxsize=capacity)
self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
self.worker_thread.start()
def log_event(self, event):
try:
self.log_queue.put_nowait(event)
except queue.Full:
print("Log queue is full, dropping event")
def _process_logs(self):
while True:
try:
event = self.log_queue.get()
# figure out how to use a thread's native loop
asyncio.run(self.api.log_event(event))
except Exception:
import traceback
traceback.print_exc()
print("Error processing log event")
finally:
self.log_queue.task_done()
def __del__(self):
self.log_queue.join()
class TraceContext:
spans: List[Span] = []
def __init__(self, logger: BackgroundLogger, trace_id: str):
self.logger = logger
self.trace_id = trace_id
def push_span(self, name: str, attributes: Dict[str, Any] = None):
current_span = self.get_current_span()
span = Span(
span_id=generate_short_uuid(),
trace_id=self.trace_id,
name=name,
start_time=datetime.now(),
parent_span_id=current_span.span_id if current_span else None,
attributes=attributes,
)
self.logger.log_event(
StructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=span.start_time,
attributes=span.attributes,
payload=SpanStartPayload(
name=span.name,
parent_span_id=span.parent_span_id,
),
)
)
self.spans.append(span)
def pop_span(self, status: SpanStatus = SpanStatus.OK):
span = self.spans.pop()
if span is not None:
self.logger.log_event(
StructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=span.start_time,
attributes=span.attributes,
payload=SpanEndPayload(
status=status,
),
)
)
def get_current_span(self):
return self.spans[-1] if self.spans else None
def setup_logger(api: Telemetry, level: int = logging.INFO):
global BACKGROUND_LOGGER
BACKGROUND_LOGGER = BackgroundLogger(api)
logger = logging.getLogger()
logger.setLevel(level)
logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: Dict[str, Any] = None):
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
print("No Telemetry implementation set. Skipping trace initialization...")
return
trace_id = generate_short_uuid()
context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})})
CURRENT_TRACE_CONTEXT = context
async def end_trace(status: SpanStatus = SpanStatus.OK):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context is None:
return
context.pop_span(status)
CURRENT_TRACE_CONTEXT = None
def severity(levelname: str) -> LogSeverity:
if levelname == "DEBUG":
return LogSeverity.DEBUG
elif levelname == "INFO":
return LogSeverity.INFO
elif levelname == "WARNING":
return LogSeverity.WARNING
elif levelname == "ERROR":
return LogSeverity.ERROR
elif levelname == "CRITICAL":
return LogSeverity.CRITICAL
else:
raise ValueError(f"Unknown log level: {levelname}")
# TODO: ideally, the actual emitting should be done inside a separate daemon
# process completely isolated from the server
class TelemetryHandler(logging.Handler):
def emit(self, record: logging.LogRecord):
# horrendous hack to avoid logging from asyncio and getting into an infinite loop
if record.module in ("asyncio", "selector_events"):
return
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
raise RuntimeError("Telemetry API not initialized")
context = CURRENT_TRACE_CONTEXT
if context is None:
return
span = context.get_current_span()
if span is None:
return
BACKGROUND_LOGGER.log_event(
UnstructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=datetime.now(),
message=self.format(record),
severity=severity(record.levelname),
)
)
def close(self):
pass
def span(name: str, attributes: Dict[str, Any] = None):
def decorator(func):
@wraps(func)
def sync_wrapper(*args, **kwargs):
try:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.push_span(name, attributes)
result = func(*args, **kwargs)
finally:
context.pop_span()
return result
@wraps(func)
async def async_wrapper(*args, **kwargs):
try:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.push_span(name, attributes)
result = await func(*args, **kwargs)
finally:
context.pop_span()
return result
@wraps(func)
def wrapper(*args, **kwargs):
if asyncio.iscoroutinefunction(func):
return async_wrapper(*args, **kwargs)
else:
return sync_wrapper(*args, **kwargs)
return wrapper
return decorator