# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Mock servers for OpenTelemetry E2E testing.

This module provides mock servers for testing telemetry:
- MockOTLPCollector: Receives and stores OTLP telemetry exports
- MockVLLMServer: Simulates vLLM inference API with valid OpenAI responses

These mocks allow E2E testing without external dependencies.
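
Example (illustrative sketch, not executed by the tests; assumes an async test
body and that the default ports 4318 and 8000 are free):

    collector = MockOTLPCollector()
    vllm = MockVLLMServer()
    await collector.await_start()
    await vllm.await_start()
    # ... point the stack's OTLP exporter at http://localhost:4318 and its
    # inference provider at http://localhost:8000, run the workload ...
    print(f"traces={collector.get_trace_count()} requests={vllm.get_request_count()}")
    collector.stop()
    vllm.stop()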
"""

import asyncio
import http.server
import json
import socket
import threading
import time
from collections import defaultdict
from typing import Any

from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import (
    ExportMetricsServiceRequest,
)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
    ExportTraceServiceRequest,
)
from pydantic import Field

from .mock_base import MockServerBase


class MockOTLPCollector(MockServerBase):
    """
    Mock OTLP collector HTTP server.

    Receives real OTLP exports from Llama Stack and stores them for verification.
    Runs on localhost:4318 (standard OTLP HTTP port).

    Usage:
        collector = MockOTLPCollector()
        await collector.await_start()
        # ... run tests ...
        print(f"Received {collector.get_trace_count()} traces")
        collector.stop()
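
    Note (illustrative; not enforced by this class): the collector only receives
    data if the stack's OTLP exporter is pointed at it, e.g. by setting
    OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 before the stack starts.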
    """

    port: int = Field(default=4318, description="Port to run collector on")

    # Non-Pydantic fields (set after initialization)
    traces: list[dict] = Field(default_factory=list, exclude=True)
    metrics: list[dict] = Field(default_factory=list, exclude=True)
    all_http_requests: list[dict] = Field(default_factory=list, exclude=True)  # Track ALL HTTP requests for debugging
    server: Any = Field(default=None, exclude=True)
    server_thread: Any = Field(default=None, exclude=True)

    def model_post_init(self, __context):
        """Initialize after Pydantic validation."""
        self.traces = []
        self.metrics = []
        self.server = None
        self.server_thread = None

    def _create_handler_class(self):
        """Create the HTTP handler class for this collector instance."""
        collector_self = self

        class OTLPHandler(http.server.BaseHTTPRequestHandler):
            """HTTP request handler for OTLP requests."""

            def log_message(self, format, *args):
                """Suppress HTTP server logs."""
                pass

            def do_GET(self):  # noqa: N802
                """Handle GET requests."""
                # No readiness endpoint needed - using await_start() instead
                self.send_response(404)
                self.end_headers()

            def do_POST(self):  # noqa: N802
                """Handle OTLP POST requests."""
                content_length = int(self.headers.get("Content-Length", 0))
                body = self.rfile.read(content_length) if content_length > 0 else b""

                # Track ALL requests for debugging
                collector_self.all_http_requests.append(
                    {
                        "method": "POST",
                        "path": self.path,
                        "timestamp": time.time(),
                        "body_length": len(body),
                    }
                )

                # Store the export request
                if "/v1/traces" in self.path:
                    collector_self.traces.append(
                        {
                            "body": body,
                            "timestamp": time.time(),
                        }
                    )
                elif "/v1/metrics" in self.path:
                    collector_self.metrics.append(
                        {
                            "body": body,
                            "timestamp": time.time(),
                        }
                    )

                # Always return success (200 OK)
                self.send_response(200)
                self.send_header("Content-Type", "application/json")
                self.end_headers()
                self.wfile.write(b"{}")

        return OTLPHandler

    async def await_start(self):
        """
        Start the OTLP collector and wait until ready.

        This method is async and can be awaited to ensure the server is ready.
        """
        # Create handler and start the HTTP server
        handler_class = self._create_handler_class()
        self.server = http.server.HTTPServer(("localhost", self.port), handler_class)
        self.server_thread = threading.Thread(target=self.server.serve_forever, daemon=True)
        self.server_thread.start()

        # Wait for server to be listening on the port
        for _ in range(10):
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                result = sock.connect_ex(("localhost", self.port))
                sock.close()
                if result == 0:
                    # Port is listening
                    return
            except Exception:
                pass
            await asyncio.sleep(0.1)

        raise RuntimeError(f"OTLP collector failed to start on port {self.port}")

    def stop(self):
        """Stop the OTLP collector server."""
        if self.server:
            self.server.shutdown()
            self.server.server_close()

    def clear(self):
        """Clear all captured telemetry data."""
        self.traces = []
        self.metrics = []

    def get_trace_count(self) -> int:
        """Get number of trace export requests received."""
        return len(self.traces)

    def get_metric_count(self) -> int:
        """Get number of metric export requests received."""
        return len(self.metrics)

    def get_all_traces(self) -> list[dict]:
        """Get all captured trace exports."""
        return self.traces

    def get_all_metrics(self) -> list[dict]:
        """Get all captured metric exports."""
        return self.metrics

    # -----------------------------
    # Trace parsing helpers
    # -----------------------------
    def parse_traces(self) -> dict[str, list[dict]]:
        """
        Parse protobuf trace data and return spans grouped by trace ID.

        Returns:
            Dict mapping trace_id (hex) -> list of span dicts
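
        Example (illustrative; assumes at least one trace export was received):
            for trace_id, spans in collector.parse_traces().items():
                names = [s["name"] for s in spans]
                print(trace_id, names)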
        """
        trace_id_to_spans: dict[str, list[dict]] = {}

        for export in self.traces:
            request = ExportTraceServiceRequest()
            body = export.get("body", b"")
            try:
                request.ParseFromString(body)
            except Exception as e:
                raise RuntimeError(f"Failed to parse OTLP traces export (len={len(body)}): {e}") from e

            for resource_span in request.resource_spans:
                for scope_span in resource_span.scope_spans:
                    for span in scope_span.spans:
                        # span.trace_id is bytes; convert to hex string
                        trace_id = (
                            span.trace_id.hex() if isinstance(span.trace_id, bytes | bytearray) else str(span.trace_id)
                        )
                        span_entry = {
                            "name": span.name,
                            "span_id": span.span_id.hex()
                            if isinstance(span.span_id, bytes | bytearray)
                            else str(span.span_id),
                            "start_time_unix_nano": int(getattr(span, "start_time_unix_nano", 0)),
                            "end_time_unix_nano": int(getattr(span, "end_time_unix_nano", 0)),
                        }
                        trace_id_to_spans.setdefault(trace_id, []).append(span_entry)

        return trace_id_to_spans

    def get_all_trace_ids(self) -> set[str]:
        """Return set of all trace IDs seen so far."""
        return set(self.parse_traces().keys())

    def get_trace_span_counts(self) -> dict[str, int]:
        """Return span counts per trace ID."""
        grouped = self.parse_traces()
        return {tid: len(spans) for tid, spans in grouped.items()}

    def get_new_trace_ids(self, prior_ids: set[str]) -> set[str]:
        """Return trace IDs that appeared after prior_ids snapshot."""
        return self.get_all_trace_ids() - set(prior_ids)

    def parse_metrics(self) -> dict[str, list[Any]]:
        """
        Parse protobuf metric data and return metrics by name.

        Returns:
            Dict mapping metric names to list of metric data points
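
        Example (illustrative; "llama_stack.prompt_tokens" is an assumed metric
        name, not one this module guarantees to exist):
            if collector.has_metric("llama_stack.prompt_tokens"):
                points = collector.get_metric_by_name("llama_stack.prompt_tokens")
                print(f"prompt_tokens data points: {len(points)}")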
        """
        metrics_by_name = defaultdict(list)

        for export in self.metrics:
            # Parse the protobuf body
            request = ExportMetricsServiceRequest()
            body = export.get("body", b"")
            try:
                request.ParseFromString(body)
            except Exception as e:
                raise RuntimeError(f"Failed to parse OTLP metrics export (len={len(body)}): {e}") from e

            # Extract metrics from the request
            for resource_metric in request.resource_metrics:
                for scope_metric in resource_metric.scope_metrics:
                    for metric in scope_metric.metrics:
                        metric_name = metric.name

                        # Extract data points based on metric type
                        data_points = []
                        if metric.HasField("gauge"):
                            data_points = list(metric.gauge.data_points)
                        elif metric.HasField("sum"):
                            data_points = list(metric.sum.data_points)
                        elif metric.HasField("histogram"):
                            data_points = list(metric.histogram.data_points)
                        elif metric.HasField("summary"):
                            data_points = list(metric.summary.data_points)

                        metrics_by_name[metric_name].extend(data_points)

        return dict(metrics_by_name)

    def get_metric_by_name(self, metric_name: str) -> list[Any]:
        """
        Get all data points for a specific metric by name.

        Args:
            metric_name: The name of the metric to retrieve

        Returns:
            List of data points for the metric, or empty list if not found
        """
        metrics = self.parse_metrics()
        return metrics.get(metric_name, [])

    def has_metric(self, metric_name: str) -> bool:
        """
        Check if a metric with the given name has been captured.

        Args:
            metric_name: The name of the metric to check

        Returns:
            True if the metric exists and has data points, False otherwise
        """
        data_points = self.get_metric_by_name(metric_name)
        return len(data_points) > 0

    def get_all_metric_names(self) -> list[str]:
        """
        Get all unique metric names that have been captured.

        Returns:
            List of metric names
        """
        return list(self.parse_metrics().keys())


class MockVLLMServer(MockServerBase):
    """
    Mock vLLM inference server with OpenAI-compatible API.

    Returns valid OpenAI Python client response objects for:
    - Chat completions (/v1/chat/completions)
    - Text completions (/v1/completions)
    - Model listing (/v1/models)

    Runs on localhost:8000 (standard vLLM port).

    Usage:
        server = MockVLLMServer(models=["my-model"])
        await server.await_start()
        # ... make inference calls ...
        print(f"Handled {server.get_request_count()} requests")
        server.stop()
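
    Example (illustrative; assumes the `openai` package, which this module does
    not require):
        from openai import OpenAI

        client = OpenAI(base_url="http://localhost:8000/v1", api_key="test-key")
        reply = client.chat.completions.create(
            model="meta-llama/Llama-3.2-1B-Instruct",
            messages=[{"role": "user", "content": "hello"}],
        )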
    """

    port: int = Field(default=8000, description="Port to run server on")
    models: list[str] = Field(
        default_factory=lambda: ["meta-llama/Llama-3.2-1B-Instruct"], description="List of model IDs to serve"
    )

    # Non-Pydantic fields
    requests_received: list[dict] = Field(default_factory=list, exclude=True)
    server: Any = Field(default=None, exclude=True)
    server_thread: Any = Field(default=None, exclude=True)

    def model_post_init(self, __context):
        """Initialize after Pydantic validation."""
        self.requests_received = []
        self.server = None
        self.server_thread = None

    def _create_handler_class(self):
        """Create the HTTP handler class for this vLLM instance."""
        server_self = self

        class VLLMHandler(http.server.BaseHTTPRequestHandler):
            """HTTP request handler for vLLM API."""

            def log_message(self, format, *args):
                """Suppress HTTP server logs."""
                pass

            def log_request(self, code="-", size="-"):
                """Log incoming requests for debugging."""
                print(f"[DEBUG] Mock vLLM received: {self.command} {self.path} -> {code}")

            def do_GET(self):  # noqa: N802
                """Handle GET requests (models list, health check)."""
                # Log GET requests too
                server_self.requests_received.append(
                    {
                        "path": self.path,
                        "method": "GET",
                        "timestamp": time.time(),
                    }
                )

                if self.path == "/v1/models":
                    response = self._create_models_list_response()
                    self._send_json_response(200, response)

                elif self.path == "/health" or self.path == "/v1/health":
                    self._send_json_response(200, {"status": "healthy"})

                else:
                    self.send_response(404)
                    self.end_headers()

            def do_POST(self):  # noqa: N802
                """Handle POST requests (chat/text completions)."""
                content_length = int(self.headers.get("Content-Length", 0))
                body = self.rfile.read(content_length) if content_length > 0 else b"{}"

                try:
                    request_data = json.loads(body)
                except Exception:
                    request_data = {}

                # Log the request
                server_self.requests_received.append(
                    {
                        "path": self.path,
                        "body": request_data,
                        "timestamp": time.time(),
                    }
                )

                # Route to appropriate handler
                if "/chat/completions" in self.path:
                    response = self._create_chat_completion_response(request_data)
                    if response is not None:  # None means already sent (streaming)
                        self._send_json_response(200, response)

                elif "/completions" in self.path:
                    response = self._create_text_completion_response(request_data)
                    self._send_json_response(200, response)

                else:
                    self._send_json_response(200, {"status": "ok"})

            # ----------------------------------------------------------------
            # Response Generators
            # **TO MODIFY RESPONSES:** Edit these methods
            # ----------------------------------------------------------------

            def _create_models_list_response(self) -> dict:
                """Create OpenAI models list response with configured models."""
                return {
                    "object": "list",
                    "data": [
                        {
                            "id": model_id,
                            "object": "model",
                            "created": int(time.time()),
                            "owned_by": "meta",
                        }
                        for model_id in server_self.models
                    ],
                }

            def _create_chat_completion_response(self, request_data: dict) -> dict | None:
                """
                Create OpenAI ChatCompletion response.

                Returns a valid response matching openai.types.ChatCompletion.
                Supports both regular and streaming responses.
                Returns None for streaming responses (already sent via SSE).
                """
                # Check if streaming is requested
                is_streaming = request_data.get("stream", False)

                if is_streaming:
                    # Return SSE streaming response
                    self.send_response(200)
                    self.send_header("Content-Type", "text/event-stream")
                    self.send_header("Cache-Control", "no-cache")
                    self.send_header("Connection", "keep-alive")
                    self.end_headers()

                    # Send streaming chunks
                    chunks = [
                        {
                            "id": "chatcmpl-test",
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": request_data.get("model", "test"),
                            "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
                        },
                        {
                            "id": "chatcmpl-test",
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": request_data.get("model", "test"),
                            "choices": [{"index": 0, "delta": {"content": "Test "}, "finish_reason": None}],
                        },
                        {
                            "id": "chatcmpl-test",
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": request_data.get("model", "test"),
                            "choices": [{"index": 0, "delta": {"content": "streaming "}, "finish_reason": None}],
                        },
                        {
                            "id": "chatcmpl-test",
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": request_data.get("model", "test"),
                            "choices": [{"index": 0, "delta": {"content": "response"}, "finish_reason": None}],
                        },
                        {
                            "id": "chatcmpl-test",
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": request_data.get("model", "test"),
                            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                        },
                    ]
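
                    # OpenAI-compatible SSE framing: each chunk is written as
                    # "data: <json>" followed by a blank line, and the stream is
                    # terminated with "data: [DONE]".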
                    for chunk in chunks:
                        self.wfile.write(f"data: {json.dumps(chunk)}\n\n".encode())
                    self.wfile.write(b"data: [DONE]\n\n")
                    return None  # Already sent response

                # Regular response
                return {
                    "id": "chatcmpl-test123",
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": request_data.get("model", "meta-llama/Llama-3.2-1B-Instruct"),
                    "choices": [
                        {
                            "index": 0,
                            "message": {
                                "role": "assistant",
                                "content": "This is a test response from mock vLLM server.",
                                "tool_calls": None,
                            },
                            "logprobs": None,
                            "finish_reason": "stop",
                        }
                    ],
                    "usage": {
                        "prompt_tokens": 25,
                        "completion_tokens": 15,
                        "total_tokens": 40,
                        "completion_tokens_details": None,
                    },
                    "system_fingerprint": None,
                    "service_tier": None,
                }

            def _create_text_completion_response(self, request_data: dict) -> dict:
                """
                Create OpenAI Completion response.

                Returns a valid response matching openai.types.Completion.
                """
                return {
                    "id": "cmpl-test123",
                    "object": "text_completion",
                    "created": int(time.time()),
                    "model": request_data.get("model", "meta-llama/Llama-3.2-1B-Instruct"),
                    "choices": [
                        {
                            "text": "This is a test completion.",
                            "index": 0,
                            "logprobs": None,
                            "finish_reason": "stop",
                        }
                    ],
                    "usage": {
                        "prompt_tokens": 10,
                        "completion_tokens": 8,
                        "total_tokens": 18,
                        "completion_tokens_details": None,
                    },
                    "system_fingerprint": None,
                }

            def _send_json_response(self, status_code: int, data: dict):
                """Helper to send JSON response."""
                self.send_response(status_code)
                self.send_header("Content-Type", "application/json")
                self.end_headers()
                self.wfile.write(json.dumps(data).encode())

        return VLLMHandler

    async def await_start(self):
        """
        Start the vLLM server and wait until ready.

        This method is async and can be awaited to ensure the server is ready.
        """
        # Create handler and start the HTTP server
        handler_class = self._create_handler_class()
        self.server = http.server.HTTPServer(("localhost", self.port), handler_class)
        self.server_thread = threading.Thread(target=self.server.serve_forever, daemon=True)
        self.server_thread.start()

        # Wait for server to be listening on the port
        for _ in range(10):
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                result = sock.connect_ex(("localhost", self.port))
                sock.close()
                if result == 0:
                    # Port is listening
                    return
            except Exception:
                pass
            await asyncio.sleep(0.1)

        raise RuntimeError(f"vLLM server failed to start on port {self.port}")

    def stop(self):
        """Stop the vLLM server."""
        if self.server:
            self.server.shutdown()
            self.server.server_close()

    def clear(self):
        """Clear request history."""
        self.requests_received = []

    def get_request_count(self) -> int:
        """Get number of requests received."""
        return len(self.requests_received)

    def get_all_requests(self) -> list[dict]:
        """Get all received requests with their bodies."""
        return self.requests_received