Simplified Telemetry API and tying it to logger (#57)

* Simplified Telemetry API and tying it to logger
* small update which adds a METRIC type
* move span events one level down into structured log events

Co-authored-by: Ashwin Bharambe <ashwin@meta.com>

parent 1433aaf9f7
commit 191cd28831

15 changed files with 524 additions and 162 deletions
@@ -9,7 +9,7 @@ from typing import List
 from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec


-def available_agentic_system_providers() -> List[ProviderSpec]:
+def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.agentic_system,
@@ -19,6 +19,7 @@ class Api(Enum):
     safety = "safety"
     agentic_system = "agentic_system"
     memory = "memory"
+    telemetry = "telemetry"


 @json_schema_type
@@ -4,17 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import importlib
 import inspect
 from typing import Dict, List

 from llama_toolchain.agentic_system.api import AgenticSystem
-from llama_toolchain.agentic_system.providers import available_agentic_system_providers
 from llama_toolchain.inference.api import Inference
-from llama_toolchain.inference.providers import available_inference_providers
 from llama_toolchain.memory.api import Memory
-from llama_toolchain.memory.providers import available_memory_providers
 from llama_toolchain.safety.api import Safety
-from llama_toolchain.safety.providers import available_safety_providers
+from llama_toolchain.telemetry.api import Telemetry

 from .datatypes import (
     Api,
@@ -44,7 +42,7 @@ def distribution_dependencies(distribution: DistributionSpec) -> List[str]:


 def stack_apis() -> List[Api]:
-    return [Api.inference, Api.safety, Api.agentic_system, Api.memory]
+    return [v for v in Api]


 def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
@@ -55,6 +53,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         Api.safety: Safety,
         Api.agentic_system: AgenticSystem,
         Api.memory: Memory,
+        Api.telemetry: Telemetry,
     }

     for api, protocol in protocols.items():
@@ -82,20 +81,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:


 def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
-    inference_providers_by_id = {
-        a.provider_type: a for a in available_inference_providers()
-    }
-    safety_providers_by_id = {a.provider_type: a for a in available_safety_providers()}
-    agentic_system_providers_by_id = {
-        a.provider_type: a for a in available_agentic_system_providers()
-    }
-
-    ret = {
-        Api.inference: inference_providers_by_id,
-        Api.safety: safety_providers_by_id,
-        Api.agentic_system: agentic_system_providers_by_id,
-        Api.memory: {a.provider_type: a for a in available_memory_providers()},
-    }
-    for k, v in ret.items():
-        v["remote"] = remote_provider_spec(k)
+    ret = {}
+    for api in stack_apis():
+        name = api.name.lower()
+        module = importlib.import_module(f"llama_toolchain.{name}.providers")
+        ret[api] = {
+            "remote": remote_provider_spec(api),
+            **{a.provider_type: a for a in module.available_providers()},
+        }
+
     return ret
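Note: after this change, api_providers() no longer hard-codes a registry per API. Instead it relies on a naming convention: every API package must expose a llama_toolchain.<api>.providers module with an available_providers() function, which is why the provider modules below are all renamed from available_<api>_providers() to available_providers(). A minimal sketch of the same lookup, assuming the packages are importable (providers_for is a hypothetical helper, not part of the diff):

    import importlib

    def providers_for(api_name: str):
        # e.g. api_name = "telemetry" -> llama_toolchain.telemetry.providers
        module = importlib.import_module(f"llama_toolchain.{api_name}.providers")
        return {spec.provider_type: spec for spec in module.available_providers()}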
@@ -21,12 +21,16 @@ def available_distribution_specs() -> List[DistributionSpec]:
                 Api.memory: "meta-reference-faiss",
                 Api.safety: "meta-reference",
                 Api.agentic_system: "meta-reference",
+                Api.telemetry: "console",
             },
         ),
         DistributionSpec(
             distribution_type="remote",
             description="Point to remote services for all llama stack APIs",
-            providers={x: "remote" for x in Api},
+            providers={
+                **{x: "remote" for x in Api},
+                Api.telemetry: "console",
+            },
         ),
         DistributionSpec(
             distribution_type="local-ollama",
@@ -36,6 +40,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
                 Api.safety: "meta-reference",
                 Api.agentic_system: "meta-reference",
                 Api.memory: "meta-reference-faiss",
+                Api.telemetry: "console",
             },
         ),
         DistributionSpec(
@@ -46,6 +51,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
                 Api.safety: "meta-reference",
                 Api.agentic_system: "meta-reference",
                 Api.memory: "meta-reference-faiss",
+                Api.telemetry: "console",
             },
         ),
         DistributionSpec(
@@ -56,6 +62,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
                 Api.safety: "meta-reference",
                 Api.agentic_system: "meta-reference",
                 Api.memory: "meta-reference-faiss",
+                Api.telemetry: "console",
             },
         ),
     ]
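Note: every built-in distribution now pins Api.telemetry to the "console" provider. In the "remote" distribution above, the providers dict first maps every API to "remote" and then overrides telemetry, relying on the fact that a later key in a Python dict literal wins. A self-contained illustration of that override (Api abridged to two members as a stand-in):

    from enum import Enum

    class Api(Enum):  # stand-in for llama_toolchain.core.datatypes.Api
        inference = "inference"
        telemetry = "telemetry"

    providers = {
        **{x: "remote" for x in Api},
        Api.telemetry: "console",  # later key wins
    }
    assert providers[Api.inference] == "remote"
    assert providers[Api.telemetry] == "console"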
@@ -38,6 +38,13 @@ from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 from typing_extensions import Annotated

+from llama_toolchain.telemetry.tracing import (
+    end_trace,
+    setup_logger,
+    SpanStatus,
+    start_trace,
+)
+
 from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
 from .distribution import api_endpoints, api_providers
 from .dynamic import instantiate_provider
@@ -88,6 +95,8 @@ async def passthrough(
     downstream_url: str,
     downstream_headers: Optional[Dict[str, str]] = None,
 ):
+    await start_trace(request.path, {"downstream_url": downstream_url})
+
     headers = dict(request.headers)
     headers.pop("host", None)
     headers.update(downstream_headers or {})
@@ -95,6 +104,7 @@ async def passthrough(
     content = await request.body()

     client = httpx.AsyncClient()
+    erred = False
     try:
         req = client.build_request(
             method=request.method,
@@ -120,17 +130,25 @@ async def passthrough(
         )

     except httpx.ReadTimeout:
+        erred = True
         return Response(content="Downstream server timed out", status_code=504)
     except httpx.NetworkError as e:
+        erred = True
         return Response(content=f"Network error: {str(e)}", status_code=502)
     except httpx.TooManyRedirects:
+        erred = True
         return Response(content="Too many redirects", status_code=502)
     except SSLError as e:
+        erred = True
         return Response(content=f"SSL error: {str(e)}", status_code=502)
     except httpx.HTTPStatusError as e:
+        erred = True
         return Response(content=str(e), status_code=e.response.status_code)
     except Exception as e:
+        erred = True
         return Response(content=f"Unexpected error: {str(e)}", status_code=500)
+    finally:
+        await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)


 def handle_sigint(*args, **kwargs):
@@ -159,7 +177,7 @@ def create_dynamic_passthrough(

 def create_dynamic_typed_route(func: Any, method: str):
     hints = get_type_hints(func)
-    response_model = hints["return"]
+    response_model = hints.get("return")

     # NOTE: I think it is better to just add a method within each Api
     # "Protocol" / adapter-impl to tell what sort of a response this request
@@ -170,6 +188,8 @@ def create_dynamic_typed_route(func: Any, method: str):
     if is_streaming:

         async def endpoint(**kwargs):
+            await start_trace(func.__name__)
+
             async def sse_generator(event_gen):
                 try:
                     async for item in event_gen:
@@ -187,6 +207,8 @@ def create_dynamic_typed_route(func: Any, method: str):
                             },
                         }
                     )
+                finally:
+                    await end_trace()

             return StreamingResponse(
                 sse_generator(func(**kwargs)), media_type="text/event-stream"
@@ -195,6 +217,7 @@ def create_dynamic_typed_route(func: Any, method: str):
     else:

         async def endpoint(**kwargs):
+            await start_trace(func.__name__)
             try:
                 return (
                     await func(**kwargs)
@@ -204,6 +227,8 @@ def create_dynamic_typed_route(func: Any, method: str):
             except Exception as e:
                 traceback.print_exception(e)
                 raise translate_exception(e) from e
+            finally:
+                await end_trace()

     sig = inspect.signature(func)
     if method == "post":
@@ -293,6 +318,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
         provider_specs[api] = providers[provider_type]

     impls = resolve_impls(provider_specs, config)
+    if Api.telemetry in impls:
+        setup_logger(impls[Api.telemetry])

     for provider_spec in provider_specs.values():
         api = provider_spec.api
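Note: the server changes above all follow one pattern: install the telemetry implementation as the process-wide logging sink at startup (setup_logger), then bracket every request, streaming or not, with start_trace()/end_trace(). A condensed, hypothetical view of that lifecycle (the real handlers are passthrough() and create_dynamic_typed_route() above; traced() is illustrative only):

    from llama_toolchain.telemetry.tracing import end_trace, start_trace, SpanStatus

    async def traced(request_name: str, handler, **kwargs):
        await start_trace(request_name)  # opens the root span for this request
        erred = False
        try:
            return await handler(**kwargs)
        except Exception:
            erred = True
            raise
        finally:
            # every exit path closes the trace, as in the finally blocks above
            await end_trace(SpanStatus.ERROR if erred else SpanStatus.OK)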
@@ -9,7 +9,7 @@ from typing import List
 from llama_toolchain.core.datatypes import *  # noqa: F403


-def available_inference_providers() -> List[ProviderSpec]:
+def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.inference,
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import logging
 import uuid

 from typing import Any, Dict, List, Optional
@@ -20,8 +21,11 @@ from llama_toolchain.memory.common.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
 )
+from llama_toolchain.telemetry import tracing
 from .config import FaissImplConfig

+logger = logging.getLogger(__name__)
+

 class FaissIndex(EmbeddingIndex):
     id_by_index: Dict[int, str]
@@ -32,11 +36,12 @@ class FaissIndex(EmbeddingIndex):
         self.id_by_index = {}
         self.chunk_by_index = {}

+    @tracing.span(name="add_chunks")
     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
         indexlen = len(self.id_by_index)
         for i, chunk in enumerate(chunks):
             self.chunk_by_index[indexlen + i] = chunk
-            print(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
+            logger.info(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
             self.id_by_index[indexlen + i] = chunk.document_id

         self.index.add(np.array(embeddings).astype(np.float32))
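Note: this hunk shows both halves of the new telemetry story: @tracing.span wraps add_chunks in a child span of the current trace, and the old print() becomes a logger call, so it flows through the TelemetryHandler that setup_logger() installs on the root logger. The same two lines are all another provider would need; a hypothetical example (MyIndex is illustrative, not part of this commit):

    import logging

    from llama_toolchain.telemetry import tracing

    logger = logging.getLogger(__name__)

    class MyIndex:  # illustrative only
        @tracing.span(name="query_chunks")
        async def query_chunks(self, query: str):
            # emitted as an UnstructuredLogEvent attached to the query_chunks span
            logger.info(f"querying for {query!r}")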
@@ -14,7 +14,7 @@ EMBEDDING_DEPS = [
 ]


-def available_memory_providers() -> List[ProviderSpec]:
+def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.memory,
@@ -9,7 +9,7 @@ from typing import List
 from llama_toolchain.core.datatypes import Api, InlineProviderSpec, ProviderSpec


-def available_safety_providers() -> List[ProviderSpec]:
+def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.safety,
@@ -6,170 +6,126 @@

 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, Union
+from typing import Any, Dict, Literal, Optional, Protocol, Union

 from llama_models.schema_utils import json_schema_type, webmethod
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated


 @json_schema_type
-class ExperimentStatus(Enum):
-    NOT_STARTED = "not_started"
-    RUNNING = "running"
-    COMPLETED = "completed"
-    FAILED = "failed"
+class SpanStatus(Enum):
+    OK = "ok"
+    ERROR = "error"


 @json_schema_type
-class Experiment(BaseModel):
-    id: str
-    name: str
-    status: ExperimentStatus
-    created_at: datetime
-    updated_at: datetime
-    metadata: Dict[str, Any]
+class Span(BaseModel):
+    span_id: str
+    trace_id: str
+    parent_span_id: Optional[str] = None
+    name: str
+    start_time: datetime
+    end_time: Optional[datetime] = None
+    attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)


 @json_schema_type
-class Run(BaseModel):
-    id: str
-    experiment_id: str
-    status: str
-    started_at: datetime
-    ended_at: Optional[datetime]
-    metadata: Dict[str, Any]
+class Trace(BaseModel):
+    trace_id: str
+    root_span_id: str
+    start_time: datetime
+    end_time: Optional[datetime] = None


 @json_schema_type
-class Metric(BaseModel):
-    name: str
-    value: Union[float, int, str, bool]
-    timestamp: datetime
-    run_id: str
-
-
-@json_schema_type
-class Log(BaseModel):
-    message: str
-    level: str
-    timestamp: datetime
-    additional_info: Dict[str, Any]
-
-
-@json_schema_type
-class ArtifactType(Enum):
-    MODEL = "model"
-    DATASET = "dataset"
-    CHECKPOINT = "checkpoint"
-    PLOT = "plot"
-    METRIC = "metric"
-    CONFIG = "config"
-    CODE = "code"
-    OTHER = "other"
+class EventType(Enum):
+    UNSTRUCTURED_LOG = "unstructured_log"
+    STRUCTURED_LOG = "structured_log"
+    METRIC = "metric"


 @json_schema_type
-class Artifact(BaseModel):
-    id: str
-    name: str
-    type: ArtifactType
-    size: int
-    created_at: datetime
-    metadata: Dict[str, Any]
+class LogSeverity(Enum):
+    VERBOSE = "verbose"
+    DEBUG = "debug"
+    INFO = "info"
+    WARN = "warn"
+    ERROR = "error"
+    CRITICAL = "critical"
+
+
+class EventCommon(BaseModel):
+    trace_id: str
+    span_id: str
+    timestamp: datetime
+    attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
+
+
+@json_schema_type
+class UnstructuredLogEvent(EventCommon):
+    type: Literal[EventType.UNSTRUCTURED_LOG.value] = EventType.UNSTRUCTURED_LOG.value
+    message: str
+    severity: LogSeverity
+
+
+@json_schema_type
+class MetricEvent(EventCommon):
+    type: Literal[EventType.METRIC.value] = EventType.METRIC.value
+    metric: str  # this would be an enum
+    value: Union[int, float]
+    unit: str
+
+
+@json_schema_type
+class StructuredLogType(Enum):
+    SPAN_START = "span_start"
+    SPAN_END = "span_end"
+
+
+@json_schema_type
+class SpanStartPayload(BaseModel):
+    type: Literal[StructuredLogType.SPAN_START.value] = (
+        StructuredLogType.SPAN_START.value
+    )
+    name: str
+    parent_span_id: Optional[str] = None


 @json_schema_type
-class CreateExperimentRequest(BaseModel):
-    name: str
-    metadata: Optional[Dict[str, Any]] = None
+class SpanEndPayload(BaseModel):
+    type: Literal[StructuredLogType.SPAN_END.value] = StructuredLogType.SPAN_END.value
+    status: SpanStatus
+
+
+StructuredLogPayload = Annotated[
+    Union[
+        SpanStartPayload,
+        SpanEndPayload,
+    ],
+    Field(discriminator="type"),
+]


 @json_schema_type
-class UpdateExperimentRequest(BaseModel):
-    experiment_id: str
-    status: Optional[ExperimentStatus] = None
-    metadata: Optional[Dict[str, Any]] = None
-
-
-@json_schema_type
-class CreateRunRequest(BaseModel):
-    experiment_id: str
-    metadata: Optional[Dict[str, Any]] = None
-
-
-@json_schema_type
-class UpdateRunRequest(BaseModel):
-    run_id: str
-    status: Optional[str] = None
-    ended_at: Optional[datetime] = None
-    metadata: Optional[Dict[str, Any]] = None
-
-
-@json_schema_type
-class LogMetricsRequest(BaseModel):
-    run_id: str
-    metrics: List[Metric]
-
-
-@json_schema_type
-class LogMessagesRequest(BaseModel):
-    logs: List[Log]
-    run_id: Optional[str] = None
-
-
-@json_schema_type
-class UploadArtifactRequest(BaseModel):
-    experiment_id: str
-    name: str
-    artifact_type: str
-    content: bytes
-    metadata: Optional[Dict[str, Any]] = None
-
-
-@json_schema_type
-class LogSearchRequest(BaseModel):
-    query: str
-    filters: Optional[Dict[str, Any]] = None
+class StructuredLogEvent(EventCommon):
+    type: Literal[EventType.STRUCTURED_LOG.value] = EventType.STRUCTURED_LOG.value
+    payload: StructuredLogPayload
+
+
+Event = Annotated[
+    Union[
+        UnstructuredLogEvent,
+        MetricEvent,
+        StructuredLogEvent,
+    ],
+    Field(discriminator="type"),
+]


 class Telemetry(Protocol):
-    @webmethod(route="/experiments/create")
-    def create_experiment(self, request: CreateExperimentRequest) -> Experiment: ...
-
-    @webmethod(route="/experiments/list")
-    def list_experiments(self) -> List[Experiment]: ...
-
-    @webmethod(route="/experiments/get")
-    def get_experiment(self, experiment_id: str) -> Experiment: ...
-
-    @webmethod(route="/experiments/update")
-    def update_experiment(self, request: UpdateExperimentRequest) -> Experiment: ...
-
-    @webmethod(route="/experiments/create_run")
-    def create_run(self, request: CreateRunRequest) -> Run: ...
-
-    @webmethod(route="/runs/update")
-    def update_run(self, request: UpdateRunRequest) -> Run: ...
-
-    @webmethod(route="/runs/log_metrics")
-    def log_metrics(self, request: LogMetricsRequest) -> None: ...
-
-    @webmethod(route="/runs/metrics", method="GET")
-    def get_metrics(self, run_id: str) -> List[Metric]: ...
-
-    @webmethod(route="/logging/log_messages")
-    def log_messages(self, request: LogMessagesRequest) -> None: ...
-
-    @webmethod(route="/logging/get_logs")
-    def get_logs(self, request: LogSearchRequest) -> List[Log]: ...
-
-    @webmethod(route="/experiments/artifacts/upload")
-    def upload_artifact(self, request: UploadArtifactRequest) -> Artifact: ...
-
-    @webmethod(route="/experiments/artifacts/get")
-    def list_artifacts(self, experiment_id: str) -> List[Artifact]: ...
-
-    @webmethod(route="/artifacts/get")
-    def get_artifact(self, artifact_id: str) -> Artifact: ...
+    @webmethod(route="/telemetry/log_event")
+    async def log_event(self, event: Event): ...
+
+    @webmethod(route="/telemetry/get_trace", method="GET")
+    async def get_trace(self, trace_id: str) -> Trace: ...
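Note: the experiments/runs/artifacts surface is gone; the API is now a single log_event() entry point over a tagged union of three event kinds (unstructured log, metric, structured log), with span start/end moved one level down into StructuredLogEvent payloads. The Literal "type" fields plus Field(discriminator="type") let pydantic route a serialized event to the right class. A small sketch constructing one event of each kind, assuming the import paths from the diff (field values are made up):

    from datetime import datetime

    from llama_toolchain.telemetry.api import (
        LogSeverity,
        MetricEvent,
        SpanStartPayload,
        StructuredLogEvent,
        UnstructuredLogEvent,
    )

    common = dict(trace_id="tr-1", span_id="sp-1", timestamp=datetime.now())

    events = [
        UnstructuredLogEvent(**common, message="hello", severity=LogSeverity.INFO),
        MetricEvent(**common, metric="prompt_tokens", value=128, unit="tokens"),
        StructuredLogEvent(**common, payload=SpanStartPayload(name="request")),
    ]
    # each serializes with its own "type" tag, e.g. "unstructured_log"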
llama_toolchain/telemetry/console/__init__.py (new file, 15 lines)

@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import ConsoleConfig
+
+
+async def get_provider_impl(config: ConsoleConfig, _deps):
+    from .console import ConsoleTelemetryImpl
+
+    impl = ConsoleTelemetryImpl(config)
+    await impl.initialize()
+    return impl
llama_toolchain/telemetry/console/config.py (new file, 13 lines)

@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_models.schema_utils import json_schema_type
+
+from pydantic import BaseModel
+
+
+@json_schema_type
+class ConsoleConfig(BaseModel): ...
llama_toolchain/telemetry/console/console.py (new file, 89 lines)

@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from llama_toolchain.telemetry.api import *  # noqa: F403
+from .config import ConsoleConfig
+
+
+class ConsoleTelemetryImpl(Telemetry):
+    def __init__(self, config: ConsoleConfig) -> None:
+        self.config = config
+        self.spans = {}
+
+    async def initialize(self) -> None: ...
+
+    async def shutdown(self) -> None: ...
+
+    async def log_event(self, event: Event):
+        if (
+            isinstance(event, StructuredLogEvent)
+            and event.payload.type == StructuredLogType.SPAN_START.value
+        ):
+            self.spans[event.span_id] = event.payload
+
+        names = []
+        span_id = event.span_id
+        while True:
+            span_payload = self.spans.get(span_id)
+            if not span_payload:
+                break
+
+            names = [span_payload.name] + names
+            span_id = span_payload.parent_span_id
+
+        span_name = ".".join(names) if names else None
+
+        formatted = format_event(event, span_name)
+        if formatted:
+            print(formatted)
+
+    async def get_trace(self, trace_id: str) -> Trace:
+        raise NotImplementedError()
+
+
+COLORS = {
+    "reset": "\033[0m",
+    "bold": "\033[1m",
+    "dim": "\033[2m",
+    "red": "\033[31m",
+    "green": "\033[32m",
+    "yellow": "\033[33m",
+    "blue": "\033[34m",
+    "magenta": "\033[35m",
+    "cyan": "\033[36m",
+    "white": "\033[37m",
+}
+
+SEVERITY_COLORS = {
+    LogSeverity.VERBOSE: COLORS["dim"] + COLORS["white"],
+    LogSeverity.DEBUG: COLORS["cyan"],
+    LogSeverity.INFO: COLORS["green"],
+    LogSeverity.WARN: COLORS["yellow"],
+    LogSeverity.ERROR: COLORS["red"],
+    LogSeverity.CRITICAL: COLORS["bold"] + COLORS["red"],
+}
+
+
+def format_event(event: Event, span_name: str) -> Optional[str]:
+    timestamp = event.timestamp.strftime("%H:%M:%S.%f")[:-3]
+    span = ""
+    if span_name:
+        span = f"{COLORS['magenta']}[{span_name}]{COLORS['reset']} "
+    if isinstance(event, UnstructuredLogEvent):
+        severity_color = SEVERITY_COLORS.get(event.severity, COLORS["reset"])
+        return (
+            f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
+            f"{severity_color}[{event.severity.name}]{COLORS['reset']} "
+            f"{span}"
+            f"{event.message}"
+        )
+
+    elif isinstance(event, StructuredLogEvent):
+        return None
+
+    return f"Unknown event type: {event}"
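Note: the console sink caches SpanStartPayloads by span_id so it can print a dotted span path (parent.child) next to each line; unstructured events are rendered with ANSI colors and structured events are swallowed. A minimal sketch of driving it directly, assuming the modules above are importable (normally the server feeds it through tracing.py's BackgroundLogger):

    import asyncio
    from datetime import datetime

    from llama_toolchain.telemetry.api import LogSeverity, UnstructuredLogEvent
    from llama_toolchain.telemetry.console import ConsoleConfig, get_provider_impl

    async def demo():
        telemetry = await get_provider_impl(ConsoleConfig(), _deps={})
        await telemetry.log_event(
            UnstructuredLogEvent(
                trace_id="tr-1",
                span_id="sp-1",
                timestamp=datetime.now(),
                message="hello from the console sink",
                severity=LogSeverity.WARN,
            )
        )  # prints a dim timestamp, a yellow [WARN], and the message

    asyncio.run(demo())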
llama_toolchain/telemetry/providers.py (new file, 21 lines)

@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_toolchain.core.datatypes import *  # noqa: F403
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.telemetry,
+            provider_type="console",
+            pip_packages=[],
+            module="llama_toolchain.telemetry.console",
+            config_class="llama_toolchain.telemetry.console.ConsoleConfig",
+        ),
+    ]
llama_toolchain/telemetry/tracing.py (new file, 236 lines)

@@ -0,0 +1,236 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+import base64
+import logging
+import queue
+import threading
+import uuid
+from datetime import datetime
+from functools import wraps
+from typing import Any, Dict, List
+
+
+from llama_toolchain.telemetry.api import *  # noqa: F403
+
+
+def generate_short_uuid(len: int = 12):
+    full_uuid = uuid.uuid4()
+    uuid_bytes = full_uuid.bytes
+    encoded = base64.urlsafe_b64encode(uuid_bytes)
+    return encoded.rstrip(b"=").decode("ascii")[:len]
+
+
+CURRENT_TRACE_CONTEXT = None
+BACKGROUND_LOGGER = None
+
+
+class BackgroundLogger:
+    def __init__(self, api: Telemetry, capacity: int = 1000):
+        self.api = api
+        self.log_queue = queue.Queue(maxsize=capacity)
+        self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
+        self.worker_thread.start()
+
+    def log_event(self, event):
+        try:
+            self.log_queue.put_nowait(event)
+        except queue.Full:
+            print("Log queue is full, dropping event")
+
+    def _process_logs(self):
+        while True:
+            try:
+                event = self.log_queue.get()
+                # figure out how to use a thread's native loop
+                asyncio.run(self.api.log_event(event))
+            except Exception:
+                import traceback
+
+                traceback.print_exc()
+                print("Error processing log event")
+            finally:
+                self.log_queue.task_done()
+
+    def __del__(self):
+        self.log_queue.join()
+
+
+class TraceContext:
+    spans: List[Span]
+
+    def __init__(self, logger: BackgroundLogger, trace_id: str):
+        self.logger = logger
+        self.trace_id = trace_id
+        # per-instance span stack; a class-level list would leak spans
+        # across traces
+        self.spans = []
+
+    def push_span(self, name: str, attributes: Dict[str, Any] = None):
+        current_span = self.get_current_span()
+        span = Span(
+            span_id=generate_short_uuid(),
+            trace_id=self.trace_id,
+            name=name,
+            start_time=datetime.now(),
+            parent_span_id=current_span.span_id if current_span else None,
+            attributes=attributes,
+        )
+
+        self.logger.log_event(
+            StructuredLogEvent(
+                trace_id=span.trace_id,
+                span_id=span.span_id,
+                timestamp=span.start_time,
+                attributes=span.attributes,
+                payload=SpanStartPayload(
+                    name=span.name,
+                    parent_span_id=span.parent_span_id,
+                ),
+            )
+        )
+
+        self.spans.append(span)
+
+    def pop_span(self, status: SpanStatus = SpanStatus.OK):
+        span = self.spans.pop()
+        if span is not None:
+            self.logger.log_event(
+                StructuredLogEvent(
+                    trace_id=span.trace_id,
+                    span_id=span.span_id,
+                    timestamp=span.start_time,
+                    attributes=span.attributes,
+                    payload=SpanEndPayload(
+                        status=status,
+                    ),
+                )
+            )
+
+    def get_current_span(self):
+        return self.spans[-1] if self.spans else None
+
+
+def setup_logger(api: Telemetry, level: int = logging.INFO):
+    global BACKGROUND_LOGGER
+
+    BACKGROUND_LOGGER = BackgroundLogger(api)
+    logger = logging.getLogger()
+    logger.setLevel(level)
+    logger.addHandler(TelemetryHandler())
+
+
+async def start_trace(name: str, attributes: Dict[str, Any] = None):
+    global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
+
+    if BACKGROUND_LOGGER is None:
+        print("No Telemetry implementation set. Skipping trace initialization...")
+        return
+
+    trace_id = generate_short_uuid()
+    context = TraceContext(BACKGROUND_LOGGER, trace_id)
+    context.push_span(name, {"__root__": True, **(attributes or {})})
+
+    CURRENT_TRACE_CONTEXT = context
+
+
+async def end_trace(status: SpanStatus = SpanStatus.OK):
+    global CURRENT_TRACE_CONTEXT
+
+    context = CURRENT_TRACE_CONTEXT
+    if context is None:
+        return
+
+    context.pop_span(status)
+    CURRENT_TRACE_CONTEXT = None
+
+
+def severity(levelname: str) -> LogSeverity:
+    if levelname == "DEBUG":
+        return LogSeverity.DEBUG
+    elif levelname == "INFO":
+        return LogSeverity.INFO
+    elif levelname == "WARNING":
+        # Python's "WARNING" level name maps to the WARN enum member
+        return LogSeverity.WARN
+    elif levelname == "ERROR":
+        return LogSeverity.ERROR
+    elif levelname == "CRITICAL":
+        return LogSeverity.CRITICAL
+    else:
+        raise ValueError(f"Unknown log level: {levelname}")
+
+
+# TODO: ideally, the actual emitting should be done inside a separate daemon
+# process completely isolated from the server
+class TelemetryHandler(logging.Handler):
+    def emit(self, record: logging.LogRecord):
+        # horrendous hack to avoid logging from asyncio and getting into an infinite loop
+        if record.module in ("asyncio", "selector_events"):
+            return
+
+        global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
+
+        if BACKGROUND_LOGGER is None:
+            raise RuntimeError("Telemetry API not initialized")
+
+        context = CURRENT_TRACE_CONTEXT
+        if context is None:
+            return
+
+        span = context.get_current_span()
+        if span is None:
+            return
+
+        BACKGROUND_LOGGER.log_event(
+            UnstructuredLogEvent(
+                trace_id=span.trace_id,
+                span_id=span.span_id,
+                timestamp=datetime.now(),
+                message=self.format(record),
+                severity=severity(record.levelname),
+            )
+        )
+
+    def close(self):
+        pass
+
+
+def span(name: str, attributes: Dict[str, Any] = None):
+    def decorator(func):
+        @wraps(func)
+        def sync_wrapper(*args, **kwargs):
+            try:
+                global CURRENT_TRACE_CONTEXT
+
+                context = CURRENT_TRACE_CONTEXT
+                if context:
+                    context.push_span(name, attributes)
+                result = func(*args, **kwargs)
+            finally:
+                # only pop if a context existed to push onto
+                if context:
+                    context.pop_span()
+            return result
+
+        @wraps(func)
+        async def async_wrapper(*args, **kwargs):
+            try:
+                global CURRENT_TRACE_CONTEXT
+
+                context = CURRENT_TRACE_CONTEXT
+                if context:
+                    context.push_span(name, attributes)
+                result = await func(*args, **kwargs)
+            finally:
+                if context:
+                    context.pop_span()
+            return result
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if asyncio.iscoroutinefunction(func):
+                return async_wrapper(*args, **kwargs)
+            else:
+                return sync_wrapper(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
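Note: tracing.py is the "tying it to logger" half of this commit. setup_logger() installs a TelemetryHandler on the root logger, so plain logging calls become UnstructuredLogEvents on the current span; start_trace()/end_trace() and the span() decorator manage the span stack; and BackgroundLogger ships events to the Telemetry implementation from a daemon thread so request handlers never block on telemetry. An end-to-end sketch against the console provider (any Telemetry implementation should work the same way):

    import asyncio
    import logging

    from llama_toolchain.telemetry.console import ConsoleConfig, get_provider_impl
    from llama_toolchain.telemetry.tracing import end_trace, setup_logger, start_trace

    async def main():
        telemetry = await get_provider_impl(ConsoleConfig(), _deps={})
        setup_logger(telemetry)  # root logger now feeds the Telemetry API

        await start_trace("demo-request")  # emits a span_start event
        logging.getLogger(__name__).info("doing work")  # attached to that span
        await end_trace()  # emits span_end with status OK

    asyncio.run(main())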