mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Simplified Telemetry API and tying it to logger (#57)
* Simplified Telemetry API and tying it to logger * small update which adds a METRIC type * move span events one level down into structured log events --------- Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
This commit is contained in:
parent
1433aaf9f7
commit
191cd28831
15 changed files with 524 additions and 162 deletions
|
@ -9,7 +9,7 @@ from typing import List
|
|||
from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
|
||||
|
||||
def available_agentic_system_providers() -> List[ProviderSpec]:
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.agentic_system,
|
||||
|
|
|
@ -19,6 +19,7 @@ class Api(Enum):
|
|||
safety = "safety"
|
||||
agentic_system = "agentic_system"
|
||||
memory = "memory"
|
||||
telemetry = "telemetry"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -4,17 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Dict, List
|
||||
|
||||
from llama_toolchain.agentic_system.api import AgenticSystem
|
||||
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
|
||||
from llama_toolchain.inference.api import Inference
|
||||
from llama_toolchain.inference.providers import available_inference_providers
|
||||
from llama_toolchain.memory.api import Memory
|
||||
from llama_toolchain.memory.providers import available_memory_providers
|
||||
from llama_toolchain.safety.api import Safety
|
||||
from llama_toolchain.safety.providers import available_safety_providers
|
||||
from llama_toolchain.telemetry.api import Telemetry
|
||||
|
||||
from .datatypes import (
|
||||
Api,
|
||||
|
@ -44,7 +42,7 @@ def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
|
|||
|
||||
|
||||
def stack_apis() -> List[Api]:
|
||||
return [Api.inference, Api.safety, Api.agentic_system, Api.memory]
|
||||
return [v for v in Api]
|
||||
|
||||
|
||||
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||
|
@ -55,6 +53,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
Api.safety: Safety,
|
||||
Api.agentic_system: AgenticSystem,
|
||||
Api.memory: Memory,
|
||||
Api.telemetry: Telemetry,
|
||||
}
|
||||
|
||||
for api, protocol in protocols.items():
|
||||
|
@ -82,20 +81,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
|
||||
|
||||
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
inference_providers_by_id = {
|
||||
a.provider_type: a for a in available_inference_providers()
|
||||
}
|
||||
safety_providers_by_id = {a.provider_type: a for a in available_safety_providers()}
|
||||
agentic_system_providers_by_id = {
|
||||
a.provider_type: a for a in available_agentic_system_providers()
|
||||
}
|
||||
ret = {}
|
||||
for api in stack_apis():
|
||||
name = api.name.lower()
|
||||
module = importlib.import_module(f"llama_toolchain.{name}.providers")
|
||||
ret[api] = {
|
||||
"remote": remote_provider_spec(api),
|
||||
**{a.provider_type: a for a in module.available_providers()},
|
||||
}
|
||||
|
||||
ret = {
|
||||
Api.inference: inference_providers_by_id,
|
||||
Api.safety: safety_providers_by_id,
|
||||
Api.agentic_system: agentic_system_providers_by_id,
|
||||
Api.memory: {a.provider_type: a for a in available_memory_providers()},
|
||||
}
|
||||
for k, v in ret.items():
|
||||
v["remote"] = remote_provider_spec(k)
|
||||
return ret
|
||||
|
|
|
@ -21,12 +21,16 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
|||
Api.memory: "meta-reference-faiss",
|
||||
Api.safety: "meta-reference",
|
||||
Api.agentic_system: "meta-reference",
|
||||
Api.telemetry: "console",
|
||||
},
|
||||
),
|
||||
DistributionSpec(
|
||||
distribution_type="remote",
|
||||
description="Point to remote services for all llama stack APIs",
|
||||
providers={x: "remote" for x in Api},
|
||||
providers={
|
||||
**{x: "remote" for x in Api},
|
||||
Api.telemetry: "console",
|
||||
},
|
||||
),
|
||||
DistributionSpec(
|
||||
distribution_type="local-ollama",
|
||||
|
@ -36,6 +40,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
|||
Api.safety: "meta-reference",
|
||||
Api.agentic_system: "meta-reference",
|
||||
Api.memory: "meta-reference-faiss",
|
||||
Api.telemetry: "console",
|
||||
},
|
||||
),
|
||||
DistributionSpec(
|
||||
|
@ -46,6 +51,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
|||
Api.safety: "meta-reference",
|
||||
Api.agentic_system: "meta-reference",
|
||||
Api.memory: "meta-reference-faiss",
|
||||
Api.telemetry: "console",
|
||||
},
|
||||
),
|
||||
DistributionSpec(
|
||||
|
@ -56,6 +62,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
|||
Api.safety: "meta-reference",
|
||||
Api.agentic_system: "meta-reference",
|
||||
Api.memory: "meta-reference-faiss",
|
||||
Api.telemetry: "console",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
|
@ -38,6 +38,13 @@ from pydantic import BaseModel, ValidationError
|
|||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_toolchain.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
|
||||
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
from .distribution import api_endpoints, api_providers
|
||||
from .dynamic import instantiate_provider
|
||||
|
@ -88,6 +95,8 @@ async def passthrough(
|
|||
downstream_url: str,
|
||||
downstream_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
await start_trace(request.path, {"downstream_url": downstream_url})
|
||||
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
headers.update(downstream_headers or {})
|
||||
|
@ -95,6 +104,7 @@ async def passthrough(
|
|||
content = await request.body()
|
||||
|
||||
client = httpx.AsyncClient()
|
||||
erred = False
|
||||
try:
|
||||
req = client.build_request(
|
||||
method=request.method,
|
||||
|
@ -120,17 +130,25 @@ async def passthrough(
|
|||
)
|
||||
|
||||
except httpx.ReadTimeout:
|
||||
erred = True
|
||||
return Response(content="Downstream server timed out", status_code=504)
|
||||
except httpx.NetworkError as e:
|
||||
erred = True
|
||||
return Response(content=f"Network error: {str(e)}", status_code=502)
|
||||
except httpx.TooManyRedirects:
|
||||
erred = True
|
||||
return Response(content="Too many redirects", status_code=502)
|
||||
except SSLError as e:
|
||||
erred = True
|
||||
return Response(content=f"SSL error: {str(e)}", status_code=502)
|
||||
except httpx.HTTPStatusError as e:
|
||||
erred = True
|
||||
return Response(content=str(e), status_code=e.response.status_code)
|
||||
except Exception as e:
|
||||
erred = True
|
||||
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
|
||||
finally:
|
||||
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
|
||||
|
||||
|
||||
def handle_sigint(*args, **kwargs):
|
||||
|
@ -159,7 +177,7 @@ def create_dynamic_passthrough(
|
|||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
hints = get_type_hints(func)
|
||||
response_model = hints["return"]
|
||||
response_model = hints.get("return")
|
||||
|
||||
# NOTE: I think it is better to just add a method within each Api
|
||||
# "Protocol" / adapter-impl to tell what sort of a response this request
|
||||
|
@ -170,6 +188,8 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
if is_streaming:
|
||||
|
||||
async def endpoint(**kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
|
@ -187,6 +207,8 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
|
@ -195,6 +217,7 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
else:
|
||||
|
||||
async def endpoint(**kwargs):
|
||||
await start_trace(func.__name__)
|
||||
try:
|
||||
return (
|
||||
await func(**kwargs)
|
||||
|
@ -204,6 +227,8 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
sig = inspect.signature(func)
|
||||
if method == "post":
|
||||
|
@ -293,6 +318,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
provider_specs[api] = providers[provider_type]
|
||||
|
||||
impls = resolve_impls(provider_specs, config)
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
for provider_spec in provider_specs.values():
|
||||
api = provider_spec.api
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import List
|
|||
from llama_toolchain.core.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def available_inference_providers() -> List[ProviderSpec]:
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
@ -20,8 +21,11 @@ from llama_toolchain.memory.common.vector_store import (
|
|||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
)
|
||||
from llama_toolchain.telemetry import tracing
|
||||
from .config import FaissImplConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FaissIndex(EmbeddingIndex):
|
||||
id_by_index: Dict[int, str]
|
||||
|
@ -32,11 +36,12 @@ class FaissIndex(EmbeddingIndex):
|
|||
self.id_by_index = {}
|
||||
self.chunk_by_index = {}
|
||||
|
||||
@tracing.span(name="add_chunks")
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
indexlen = len(self.id_by_index)
|
||||
for i, chunk in enumerate(chunks):
|
||||
self.chunk_by_index[indexlen + i] = chunk
|
||||
print(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
|
||||
logger.info(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
|
||||
self.id_by_index[indexlen + i] = chunk.document_id
|
||||
|
||||
self.index.add(np.array(embeddings).astype(np.float32))
|
||||
|
|
|
@ -14,7 +14,7 @@ EMBEDDING_DEPS = [
|
|||
]
|
||||
|
||||
|
||||
def available_memory_providers() -> List[ProviderSpec]:
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.memory,
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import List
|
|||
from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
|
||||
|
||||
def available_safety_providers() -> List[ProviderSpec]:
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
|
|
|
@ -6,170 +6,126 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||
from typing import Any, Dict, Literal, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ExperimentStatus(Enum):
|
||||
NOT_STARTED = "not_started"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
class SpanStatus(Enum):
|
||||
OK = "ok"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Experiment(BaseModel):
|
||||
id: str
|
||||
class Span(BaseModel):
|
||||
span_id: str
|
||||
trace_id: str
|
||||
parent_span_id: Optional[str] = None
|
||||
name: str
|
||||
status: ExperimentStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
metadata: Dict[str, Any]
|
||||
start_time: datetime
|
||||
end_time: Optional[datetime] = None
|
||||
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Run(BaseModel):
|
||||
id: str
|
||||
experiment_id: str
|
||||
status: str
|
||||
started_at: datetime
|
||||
ended_at: Optional[datetime]
|
||||
metadata: Dict[str, Any]
|
||||
class Trace(BaseModel):
|
||||
trace_id: str
|
||||
root_span_id: str
|
||||
start_time: datetime
|
||||
end_time: Optional[datetime] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Metric(BaseModel):
|
||||
name: str
|
||||
value: Union[float, int, str, bool]
|
||||
timestamp: datetime
|
||||
run_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Log(BaseModel):
|
||||
message: str
|
||||
level: str
|
||||
timestamp: datetime
|
||||
additional_info: Dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ArtifactType(Enum):
|
||||
MODEL = "model"
|
||||
DATASET = "dataset"
|
||||
CHECKPOINT = "checkpoint"
|
||||
PLOT = "plot"
|
||||
class EventType(Enum):
|
||||
UNSTRUCTURED_LOG = "unstructured_log"
|
||||
STRUCTURED_LOG = "structured_log"
|
||||
METRIC = "metric"
|
||||
CONFIG = "config"
|
||||
CODE = "code"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Artifact(BaseModel):
|
||||
id: str
|
||||
class LogSeverity(Enum):
|
||||
VERBOSE = "verbose"
|
||||
DEBUG = "debug"
|
||||
INFO = "info"
|
||||
WARN = "warn"
|
||||
ERROR = "error"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class EventCommon(BaseModel):
|
||||
trace_id: str
|
||||
span_id: str
|
||||
timestamp: datetime
|
||||
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UnstructuredLogEvent(EventCommon):
|
||||
type: Literal[EventType.UNSTRUCTURED_LOG.value] = EventType.UNSTRUCTURED_LOG.value
|
||||
message: str
|
||||
severity: LogSeverity
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricEvent(EventCommon):
|
||||
type: Literal[EventType.METRIC.value] = EventType.METRIC.value
|
||||
metric: str # this would be an enum
|
||||
value: Union[int, float]
|
||||
unit: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class StructuredLogType(Enum):
|
||||
SPAN_START = "span_start"
|
||||
SPAN_END = "span_end"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStartPayload(BaseModel):
|
||||
type: Literal[StructuredLogType.SPAN_START.value] = (
|
||||
StructuredLogType.SPAN_START.value
|
||||
)
|
||||
name: str
|
||||
type: ArtifactType
|
||||
size: int
|
||||
created_at: datetime
|
||||
metadata: Dict[str, Any]
|
||||
parent_span_id: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CreateExperimentRequest(BaseModel):
|
||||
name: str
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
class SpanEndPayload(BaseModel):
|
||||
type: Literal[StructuredLogType.SPAN_END.value] = StructuredLogType.SPAN_END.value
|
||||
status: SpanStatus
|
||||
|
||||
|
||||
StructuredLogPayload = Annotated[
|
||||
Union[
|
||||
SpanStartPayload,
|
||||
SpanEndPayload,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UpdateExperimentRequest(BaseModel):
|
||||
experiment_id: str
|
||||
status: Optional[ExperimentStatus] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
class StructuredLogEvent(EventCommon):
|
||||
type: Literal[EventType.STRUCTURED_LOG.value] = EventType.STRUCTURED_LOG.value
|
||||
payload: StructuredLogPayload
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CreateRunRequest(BaseModel):
|
||||
experiment_id: str
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UpdateRunRequest(BaseModel):
|
||||
run_id: str
|
||||
status: Optional[str] = None
|
||||
ended_at: Optional[datetime] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LogMetricsRequest(BaseModel):
|
||||
run_id: str
|
||||
metrics: List[Metric]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LogMessagesRequest(BaseModel):
|
||||
logs: List[Log]
|
||||
run_id: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UploadArtifactRequest(BaseModel):
|
||||
experiment_id: str
|
||||
name: str
|
||||
artifact_type: str
|
||||
content: bytes
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LogSearchRequest(BaseModel):
|
||||
query: str
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
Event = Annotated[
|
||||
Union[
|
||||
UnstructuredLogEvent,
|
||||
MetricEvent,
|
||||
StructuredLogEvent,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class Telemetry(Protocol):
|
||||
@webmethod(route="/experiments/create")
|
||||
def create_experiment(self, request: CreateExperimentRequest) -> Experiment: ...
|
||||
@webmethod(route="/telemetry/log_event")
|
||||
async def log_event(self, event: Event): ...
|
||||
|
||||
@webmethod(route="/experiments/list")
|
||||
def list_experiments(self) -> List[Experiment]: ...
|
||||
|
||||
@webmethod(route="/experiments/get")
|
||||
def get_experiment(self, experiment_id: str) -> Experiment: ...
|
||||
|
||||
@webmethod(route="/experiments/update")
|
||||
def update_experiment(self, request: UpdateExperimentRequest) -> Experiment: ...
|
||||
|
||||
@webmethod(route="/experiments/create_run")
|
||||
def create_run(self, request: CreateRunRequest) -> Run: ...
|
||||
|
||||
@webmethod(route="/runs/update")
|
||||
def update_run(self, request: UpdateRunRequest) -> Run: ...
|
||||
|
||||
@webmethod(route="/runs/log_metrics")
|
||||
def log_metrics(self, request: LogMetricsRequest) -> None: ...
|
||||
|
||||
@webmethod(route="/runs/metrics", method="GET")
|
||||
def get_metrics(self, run_id: str) -> List[Metric]: ...
|
||||
|
||||
@webmethod(route="/logging/log_messages")
|
||||
def log_messages(self, request: LogMessagesRequest) -> None: ...
|
||||
|
||||
@webmethod(route="/logging/get_logs")
|
||||
def get_logs(self, request: LogSearchRequest) -> List[Log]: ...
|
||||
|
||||
@webmethod(route="/experiments/artifacts/upload")
|
||||
def upload_artifact(self, request: UploadArtifactRequest) -> Artifact: ...
|
||||
|
||||
@webmethod(route="/experiments/artifacts/get")
|
||||
def list_artifacts(self, experiment_id: str) -> List[Artifact]: ...
|
||||
|
||||
@webmethod(route="/artifacts/get")
|
||||
def get_artifact(self, artifact_id: str) -> Artifact: ...
|
||||
@webmethod(route="/telemetry/get_trace", method="GET")
|
||||
async def get_trace(self, trace_id: str) -> Trace: ...
|
||||
|
|
15
llama_toolchain/telemetry/console/__init__.py
Normal file
15
llama_toolchain/telemetry/console/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import ConsoleConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: ConsoleConfig, _deps):
|
||||
from .console import ConsoleTelemetryImpl
|
||||
|
||||
impl = ConsoleTelemetryImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
13
llama_toolchain/telemetry/console/config.py
Normal file
13
llama_toolchain/telemetry/console/config.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConsoleConfig(BaseModel): ...
|
89
llama_toolchain/telemetry/console/console.py
Normal file
89
llama_toolchain/telemetry/console/console.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_toolchain.telemetry.api import * # noqa: F403
|
||||
from .config import ConsoleConfig
|
||||
|
||||
|
||||
class ConsoleTelemetryImpl(Telemetry):
|
||||
def __init__(self, config: ConsoleConfig) -> None:
|
||||
self.config = config
|
||||
self.spans = {}
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def log_event(self, event: Event):
|
||||
if (
|
||||
isinstance(event, StructuredLogEvent)
|
||||
and event.payload.type == StructuredLogType.SPAN_START.value
|
||||
):
|
||||
self.spans[event.span_id] = event.payload
|
||||
|
||||
names = []
|
||||
span_id = event.span_id
|
||||
while True:
|
||||
span_payload = self.spans.get(span_id)
|
||||
if not span_payload:
|
||||
break
|
||||
|
||||
names = [span_payload.name] + names
|
||||
span_id = span_payload.parent_span_id
|
||||
|
||||
span_name = ".".join(names) if names else None
|
||||
|
||||
formatted = format_event(event, span_name)
|
||||
if formatted:
|
||||
print(formatted)
|
||||
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
COLORS = {
|
||||
"reset": "\033[0m",
|
||||
"bold": "\033[1m",
|
||||
"dim": "\033[2m",
|
||||
"red": "\033[31m",
|
||||
"green": "\033[32m",
|
||||
"yellow": "\033[33m",
|
||||
"blue": "\033[34m",
|
||||
"magenta": "\033[35m",
|
||||
"cyan": "\033[36m",
|
||||
"white": "\033[37m",
|
||||
}
|
||||
|
||||
SEVERITY_COLORS = {
|
||||
LogSeverity.VERBOSE: COLORS["dim"] + COLORS["white"],
|
||||
LogSeverity.DEBUG: COLORS["cyan"],
|
||||
LogSeverity.INFO: COLORS["green"],
|
||||
LogSeverity.WARN: COLORS["yellow"],
|
||||
LogSeverity.ERROR: COLORS["red"],
|
||||
LogSeverity.CRITICAL: COLORS["bold"] + COLORS["red"],
|
||||
}
|
||||
|
||||
|
||||
def format_event(event: Event, span_name: str) -> Optional[str]:
|
||||
timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
|
||||
span = ""
|
||||
if span_name:
|
||||
span = f"{COLORS['magenta']}[{span_name}]{COLORS['reset']} "
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
severity_color = SEVERITY_COLORS.get(event.severity, COLORS["reset"])
|
||||
return (
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
f"{severity_color}[{event.severity.name}]{COLORS['reset']} "
|
||||
f"{span}"
|
||||
f"{event.message}"
|
||||
)
|
||||
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
return None
|
||||
|
||||
return f"Unknown event type: {event}"
|
21
llama_toolchain/telemetry/providers.py
Normal file
21
llama_toolchain/telemetry/providers.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_toolchain.core.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.telemetry,
|
||||
provider_type="console",
|
||||
pip_packages=[],
|
||||
module="llama_toolchain.telemetry.console",
|
||||
config_class="llama_toolchain.telemetry.console.ConsoleConfig",
|
||||
),
|
||||
]
|
236
llama_toolchain/telemetry/tracing.py
Normal file
236
llama_toolchain/telemetry/tracing.py
Normal file
|
@ -0,0 +1,236 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
from llama_toolchain.telemetry.api import * # noqa: F403
|
||||
|
||||
|
||||
def generate_short_uuid(len: int = 12):
|
||||
full_uuid = uuid.uuid4()
|
||||
uuid_bytes = full_uuid.bytes
|
||||
encoded = base64.urlsafe_b64encode(uuid_bytes)
|
||||
return encoded.rstrip(b"=").decode("ascii")[:len]
|
||||
|
||||
|
||||
CURRENT_TRACE_CONTEXT = None
|
||||
BACKGROUND_LOGGER = None
|
||||
|
||||
|
||||
class BackgroundLogger:
|
||||
def __init__(self, api: Telemetry, capacity: int = 1000):
|
||||
self.api = api
|
||||
self.log_queue = queue.Queue(maxsize=capacity)
|
||||
self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
|
||||
self.worker_thread.start()
|
||||
|
||||
def log_event(self, event):
|
||||
try:
|
||||
self.log_queue.put_nowait(event)
|
||||
except queue.Full:
|
||||
print("Log queue is full, dropping event")
|
||||
|
||||
def _process_logs(self):
|
||||
while True:
|
||||
try:
|
||||
event = self.log_queue.get()
|
||||
# figure out how to use a thread's native loop
|
||||
asyncio.run(self.api.log_event(event))
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print("Error processing log event")
|
||||
finally:
|
||||
self.log_queue.task_done()
|
||||
|
||||
def __del__(self):
|
||||
self.log_queue.join()
|
||||
|
||||
|
||||
class TraceContext:
|
||||
spans: List[Span] = []
|
||||
|
||||
def __init__(self, logger: BackgroundLogger, trace_id: str):
|
||||
self.logger = logger
|
||||
self.trace_id = trace_id
|
||||
|
||||
def push_span(self, name: str, attributes: Dict[str, Any] = None):
|
||||
current_span = self.get_current_span()
|
||||
span = Span(
|
||||
span_id=generate_short_uuid(),
|
||||
trace_id=self.trace_id,
|
||||
name=name,
|
||||
start_time=datetime.now(),
|
||||
parent_span_id=current_span.span_id if current_span else None,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
self.logger.log_event(
|
||||
StructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
timestamp=span.start_time,
|
||||
attributes=span.attributes,
|
||||
payload=SpanStartPayload(
|
||||
name=span.name,
|
||||
parent_span_id=span.parent_span_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.spans.append(span)
|
||||
|
||||
def pop_span(self, status: SpanStatus = SpanStatus.OK):
|
||||
span = self.spans.pop()
|
||||
if span is not None:
|
||||
self.logger.log_event(
|
||||
StructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
timestamp=span.start_time,
|
||||
attributes=span.attributes,
|
||||
payload=SpanEndPayload(
|
||||
status=status,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def get_current_span(self):
|
||||
return self.spans[-1] if self.spans else None
|
||||
|
||||
|
||||
def setup_logger(api: Telemetry, level: int = logging.INFO):
|
||||
global BACKGROUND_LOGGER
|
||||
|
||||
BACKGROUND_LOGGER = BackgroundLogger(api)
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(level)
|
||||
logger.addHandler(TelemetryHandler())
|
||||
|
||||
|
||||
async def start_trace(name: str, attributes: Dict[str, Any] = None):
|
||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
print("No Telemetry implementation set. Skipping trace initialization...")
|
||||
return
|
||||
|
||||
trace_id = generate_short_uuid()
|
||||
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
||||
context.push_span(name, {"__root__": True, **(attributes or {})})
|
||||
|
||||
CURRENT_TRACE_CONTEXT = context
|
||||
|
||||
|
||||
async def end_trace(status: SpanStatus = SpanStatus.OK):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context is None:
|
||||
return
|
||||
|
||||
context.pop_span(status)
|
||||
CURRENT_TRACE_CONTEXT = None
|
||||
|
||||
|
||||
def severity(levelname: str) -> LogSeverity:
|
||||
if levelname == "DEBUG":
|
||||
return LogSeverity.DEBUG
|
||||
elif levelname == "INFO":
|
||||
return LogSeverity.INFO
|
||||
elif levelname == "WARNING":
|
||||
return LogSeverity.WARNING
|
||||
elif levelname == "ERROR":
|
||||
return LogSeverity.ERROR
|
||||
elif levelname == "CRITICAL":
|
||||
return LogSeverity.CRITICAL
|
||||
else:
|
||||
raise ValueError(f"Unknown log level: {levelname}")
|
||||
|
||||
|
||||
# TODO: ideally, the actual emitting should be done inside a separate daemon
|
||||
# process completely isolated from the server
|
||||
class TelemetryHandler(logging.Handler):
|
||||
def emit(self, record: logging.LogRecord):
|
||||
# horrendous hack to avoid logging from asyncio and getting into an infinite loop
|
||||
if record.module in ("asyncio", "selector_events"):
|
||||
return
|
||||
|
||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
raise RuntimeError("Telemetry API not initialized")
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context is None:
|
||||
return
|
||||
|
||||
span = context.get_current_span()
|
||||
if span is None:
|
||||
return
|
||||
|
||||
BACKGROUND_LOGGER.log_event(
|
||||
UnstructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
timestamp=datetime.now(),
|
||||
message=self.format(record),
|
||||
severity=severity(record.levelname),
|
||||
)
|
||||
)
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
def span(name: str, attributes: Dict[str, Any] = None):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
try:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
context.push_span(name, attributes)
|
||||
result = func(*args, **kwargs)
|
||||
finally:
|
||||
context.pop_span()
|
||||
return result
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
try:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
context.push_span(name, attributes)
|
||||
result = await func(*args, **kwargs)
|
||||
finally:
|
||||
context.pop_span()
|
||||
return result
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper(*args, **kwargs)
|
||||
else:
|
||||
return sync_wrapper(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
Loading…
Add table
Add a link
Reference in a new issue