fix(pr specific): passes pre-commit

This commit is contained in:
Emilio Garcia 2025-10-03 12:35:09 -04:00
parent 4aa2dc110d
commit 2b7a765d02
20 changed files with 547 additions and 516 deletions

View file

@ -24,7 +24,7 @@ class MockServerBase(BaseModel):
async def await_start(self):
# Start server and wait until ready
...
def stop(self):
# Stop server and cleanup
...
@ -49,29 +49,29 @@ Add to `servers.py`:
```python
class MockRedisServer(MockServerBase):
"""Mock Redis server."""
port: int = Field(default=6379)
# Non-Pydantic fields
server: Any = Field(default=None, exclude=True)
def model_post_init(self, __context):
self.server = None
async def await_start(self):
"""Start Redis mock and wait until ready."""
# Start your server
self.server = create_redis_server(self.port)
self.server.start()
# Wait for port to be listening
for _ in range(10):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if sock.connect_ex(('localhost', self.port)) == 0:
if sock.connect_ex(("localhost", self.port)) == 0:
sock.close()
return # Ready!
await asyncio.sleep(0.1)
def stop(self):
if self.server:
self.server.stop()
@ -101,11 +101,11 @@ The harness automatically:
## Benefits
**Parallel Startup** - All servers start simultaneously
**Type-Safe** - Pydantic validation
**Simple** - Just implement 2 methods
**Fast** - No HTTP polling, direct port checking
**Clean** - Async/await pattern
**Parallel Startup** - All servers start simultaneously
**Type-Safe** - Pydantic validation
**Simple** - Just implement 2 methods
**Fast** - No HTTP polling, direct port checking
**Clean** - Async/await pattern
## Usage in Tests
@ -116,6 +116,7 @@ def mock_servers():
yield servers
stop_mock_servers(servers)
# Access specific servers
@pytest.fixture(scope="module")
def mock_redis(mock_servers):

View file

@ -14,9 +14,9 @@ This module provides:
- Mock server harness for parallel async startup
"""
from .harness import MockServerConfig, start_mock_servers_async, stop_mock_servers
from .mock_base import MockServerBase
from .servers import MockOTLPCollector, MockVLLMServer
from .harness import MockServerConfig, start_mock_servers_async, stop_mock_servers
__all__ = [
"MockServerBase",
@ -26,4 +26,3 @@ __all__ = [
"start_mock_servers_async",
"stop_mock_servers",
]

View file

@ -14,7 +14,7 @@ HOW TO ADD A NEW MOCK SERVER:
"""
import asyncio
from typing import Any, Dict, List
from typing import Any
from pydantic import BaseModel, Field
@ -24,10 +24,10 @@ from .mock_base import MockServerBase
class MockServerConfig(BaseModel):
"""
Configuration for a mock server to start.
**TO ADD A NEW MOCK SERVER:**
Just create a MockServerConfig instance with your server class.
Example:
MockServerConfig(
name="Mock MyService",
@ -35,73 +35,72 @@ class MockServerConfig(BaseModel):
init_kwargs={"port": 9000, "config_param": "value"},
)
"""
model_config = {"arbitrary_types_allowed": True}
name: str = Field(description="Display name for logging")
server_class: type = Field(description="Mock server class (must inherit from MockServerBase)")
init_kwargs: Dict[str, Any] = Field(default_factory=dict, description="Kwargs to pass to server constructor")
init_kwargs: dict[str, Any] = Field(default_factory=dict, description="Kwargs to pass to server constructor")
async def start_mock_servers_async(mock_servers_config: List[MockServerConfig]) -> Dict[str, MockServerBase]:
async def start_mock_servers_async(mock_servers_config: list[MockServerConfig]) -> dict[str, MockServerBase]:
"""
Start all mock servers in parallel and wait for them to be ready.
**HOW IT WORKS:**
1. Creates all server instances
2. Calls await_start() on all servers in parallel
3. Returns when all are ready
**SIMPLE TO USE:**
servers = await start_mock_servers_async([config1, config2, ...])
Args:
mock_servers_config: List of mock server configurations
Returns:
Dict mapping server name to server instance
"""
servers = {}
start_tasks = []
# Create all servers and prepare start tasks
for config in mock_servers_config:
server = config.server_class(**config.init_kwargs)
servers[config.name] = server
start_tasks.append(server.await_start())
# Start all servers in parallel
try:
await asyncio.gather(*start_tasks)
# Print readiness confirmation
for name in servers.keys():
print(f"[INFO] {name} ready")
except Exception as e:
# If any server fails, stop all servers
for server in servers.values():
try:
server.stop()
except:
except Exception:
pass
raise RuntimeError(f"Failed to start mock servers: {e}")
raise RuntimeError(f"Failed to start mock servers: {e}") from None
return servers
def stop_mock_servers(servers: Dict[str, Any]):
def stop_mock_servers(servers: dict[str, Any]):
"""
Stop all mock servers.
Args:
servers: Dict of server instances from start_mock_servers_async()
"""
for name, server in servers.items():
try:
if hasattr(server, 'get_request_count'):
if hasattr(server, "get_request_count"):
print(f"\n[INFO] {name} received {server.get_request_count()} requests")
server.stop()
except Exception as e:
print(f"[WARN] Error stopping {name}: {e}")

View file

@ -10,25 +10,25 @@ Base class for mock servers with async startup support.
All mock servers should inherit from MockServerBase and implement await_start().
"""
import asyncio
from abc import abstractmethod
from pydantic import BaseModel, Field
from pydantic import BaseModel
class MockServerBase(BaseModel):
"""
Pydantic base model for mock servers.
**TO CREATE A NEW MOCK SERVER:**
1. Inherit from this class
2. Implement async def await_start(self)
3. Implement def stop(self)
4. Done!
Example:
class MyMockServer(MockServerBase):
port: int = 8080
async def await_start(self):
# Start your server
self.server = create_server()
@ -36,34 +36,33 @@ class MockServerBase(BaseModel):
# Wait until ready (can check internal state, no HTTP needed)
while not self.server.is_listening():
await asyncio.sleep(0.1)
def stop(self):
if self.server:
self.server.stop()
"""
model_config = {"arbitrary_types_allowed": True}
@abstractmethod
async def await_start(self):
"""
Start the server and wait until it's ready.
This method should:
1. Start the server (synchronous or async)
2. Wait until the server is fully ready to accept requests
3. Return when ready
Subclasses can check internal state directly - no HTTP polling needed!
"""
...
@abstractmethod
def stop(self):
"""
Stop the server and clean up resources.
This method should gracefully shut down the server.
"""
...

View file

@ -20,7 +20,7 @@ import json
import socket
import threading
import time
from typing import Any, Dict, List
from typing import Any
from pydantic import Field
@ -30,10 +30,10 @@ from .mock_base import MockServerBase
class MockOTLPCollector(MockServerBase):
"""
Mock OTLP collector HTTP server.
Receives real OTLP exports from Llama Stack and stores them for verification.
Runs on localhost:4318 (standard OTLP HTTP port).
Usage:
collector = MockOTLPCollector()
await collector.await_start()
@ -41,115 +41,119 @@ class MockOTLPCollector(MockServerBase):
print(f"Received {collector.get_trace_count()} traces")
collector.stop()
"""
port: int = Field(default=4318, description="Port to run collector on")
# Non-Pydantic fields (set after initialization)
traces: List[Dict] = Field(default_factory=list, exclude=True)
metrics: List[Dict] = Field(default_factory=list, exclude=True)
traces: list[dict] = Field(default_factory=list, exclude=True)
metrics: list[dict] = Field(default_factory=list, exclude=True)
server: Any = Field(default=None, exclude=True)
server_thread: Any = Field(default=None, exclude=True)
def model_post_init(self, __context):
"""Initialize after Pydantic validation."""
self.traces = []
self.metrics = []
self.server = None
self.server_thread = None
def _create_handler_class(self):
"""Create the HTTP handler class for this collector instance."""
collector_self = self
class OTLPHandler(http.server.BaseHTTPRequestHandler):
"""HTTP request handler for OTLP requests."""
def log_message(self, format, *args):
"""Suppress HTTP server logs."""
pass
def do_GET(self):
def do_GET(self): # noqa: N802
"""Handle GET requests."""
# No readiness endpoint needed - using await_start() instead
self.send_response(404)
self.end_headers()
def do_POST(self):
def do_POST(self): # noqa: N802
"""Handle OTLP POST requests."""
content_length = int(self.headers.get('Content-Length', 0))
body = self.rfile.read(content_length) if content_length > 0 else b''
content_length = int(self.headers.get("Content-Length", 0))
body = self.rfile.read(content_length) if content_length > 0 else b""
# Store the export request
if '/v1/traces' in self.path:
collector_self.traces.append({
'body': body,
'timestamp': time.time(),
})
elif '/v1/metrics' in self.path:
collector_self.metrics.append({
'body': body,
'timestamp': time.time(),
})
if "/v1/traces" in self.path:
collector_self.traces.append(
{
"body": body,
"timestamp": time.time(),
}
)
elif "/v1/metrics" in self.path:
collector_self.metrics.append(
{
"body": body,
"timestamp": time.time(),
}
)
# Always return success (200 OK)
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.send_header("Content-Type", "application/json")
self.end_headers()
self.wfile.write(b'{}')
self.wfile.write(b"{}")
return OTLPHandler
async def await_start(self):
"""
Start the OTLP collector and wait until ready.
This method is async and can be awaited to ensure the server is ready.
"""
# Create handler and start the HTTP server
handler_class = self._create_handler_class()
self.server = http.server.HTTPServer(('localhost', self.port), handler_class)
self.server = http.server.HTTPServer(("localhost", self.port), handler_class)
self.server_thread = threading.Thread(target=self.server.serve_forever, daemon=True)
self.server_thread.start()
# Wait for server to be listening on the port
for _ in range(10):
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = sock.connect_ex(('localhost', self.port))
result = sock.connect_ex(("localhost", self.port))
sock.close()
if result == 0:
# Port is listening
return
except:
except Exception:
pass
await asyncio.sleep(0.1)
raise RuntimeError(f"OTLP collector failed to start on port {self.port}")
def stop(self):
"""Stop the OTLP collector server."""
if self.server:
self.server.shutdown()
self.server.server_close()
def clear(self):
"""Clear all captured telemetry data."""
self.traces = []
self.metrics = []
def get_trace_count(self) -> int:
"""Get number of trace export requests received."""
return len(self.traces)
def get_metric_count(self) -> int:
"""Get number of metric export requests received."""
return len(self.metrics)
def get_all_traces(self) -> List[Dict]:
def get_all_traces(self) -> list[dict]:
"""Get all captured trace exports."""
return self.traces
def get_all_metrics(self) -> List[Dict]:
def get_all_metrics(self) -> list[dict]:
"""Get all captured metric exports."""
return self.metrics
@ -157,14 +161,14 @@ class MockOTLPCollector(MockServerBase):
class MockVLLMServer(MockServerBase):
"""
Mock vLLM inference server with OpenAI-compatible API.
Returns valid OpenAI Python client response objects for:
- Chat completions (/v1/chat/completions)
- Text completions (/v1/completions)
- Model listing (/v1/models)
Runs on localhost:8000 (standard vLLM port).
Usage:
server = MockVLLMServer(models=["my-model"])
await server.await_start()
@ -172,94 +176,97 @@ class MockVLLMServer(MockServerBase):
print(f"Handled {server.get_request_count()} requests")
server.stop()
"""
port: int = Field(default=8000, description="Port to run server on")
models: List[str] = Field(
default_factory=lambda: ["meta-llama/Llama-3.2-1B-Instruct"],
description="List of model IDs to serve"
models: list[str] = Field(
default_factory=lambda: ["meta-llama/Llama-3.2-1B-Instruct"], description="List of model IDs to serve"
)
# Non-Pydantic fields
requests_received: List[Dict] = Field(default_factory=list, exclude=True)
requests_received: list[dict] = Field(default_factory=list, exclude=True)
server: Any = Field(default=None, exclude=True)
server_thread: Any = Field(default=None, exclude=True)
def model_post_init(self, __context):
"""Initialize after Pydantic validation."""
self.requests_received = []
self.server = None
self.server_thread = None
def _create_handler_class(self):
"""Create the HTTP handler class for this vLLM instance."""
server_self = self
class VLLMHandler(http.server.BaseHTTPRequestHandler):
"""HTTP request handler for vLLM API."""
def log_message(self, format, *args):
"""Suppress HTTP server logs."""
pass
def log_request(self, code='-', size='-'):
def log_request(self, code="-", size="-"):
"""Log incoming requests for debugging."""
print(f"[DEBUG] Mock vLLM received: {self.command} {self.path} -> {code}")
def do_GET(self):
def do_GET(self): # noqa: N802
"""Handle GET requests (models list, health check)."""
# Log GET requests too
server_self.requests_received.append({
'path': self.path,
'method': 'GET',
'timestamp': time.time(),
})
if self.path == '/v1/models':
server_self.requests_received.append(
{
"path": self.path,
"method": "GET",
"timestamp": time.time(),
}
)
if self.path == "/v1/models":
response = self._create_models_list_response()
self._send_json_response(200, response)
elif self.path == '/health' or self.path == '/v1/health':
elif self.path == "/health" or self.path == "/v1/health":
self._send_json_response(200, {"status": "healthy"})
else:
self.send_response(404)
self.end_headers()
def do_POST(self):
def do_POST(self): # noqa: N802
"""Handle POST requests (chat/text completions)."""
content_length = int(self.headers.get('Content-Length', 0))
body = self.rfile.read(content_length) if content_length > 0 else b'{}'
content_length = int(self.headers.get("Content-Length", 0))
body = self.rfile.read(content_length) if content_length > 0 else b"{}"
try:
request_data = json.loads(body)
except:
except Exception:
request_data = {}
# Log the request
server_self.requests_received.append({
'path': self.path,
'body': request_data,
'timestamp': time.time(),
})
server_self.requests_received.append(
{
"path": self.path,
"body": request_data,
"timestamp": time.time(),
}
)
# Route to appropriate handler
if '/chat/completions' in self.path:
if "/chat/completions" in self.path:
response = self._create_chat_completion_response(request_data)
self._send_json_response(200, response)
elif '/completions' in self.path:
elif "/completions" in self.path:
response = self._create_text_completion_response(request_data)
self._send_json_response(200, response)
else:
self._send_json_response(200, {"status": "ok"})
# ----------------------------------------------------------------
# Response Generators
# **TO MODIFY RESPONSES:** Edit these methods
# ----------------------------------------------------------------
def _create_models_list_response(self) -> Dict:
def _create_models_list_response(self) -> dict:
"""Create OpenAI models list response with configured models."""
return {
"object": "list",
@ -271,13 +278,13 @@ class MockVLLMServer(MockServerBase):
"owned_by": "meta",
}
for model_id in server_self.models
]
],
}
def _create_chat_completion_response(self, request_data: Dict) -> Dict:
def _create_chat_completion_response(self, request_data: dict) -> dict:
"""
Create OpenAI ChatCompletion response.
Returns a valid response matching openai.types.ChatCompletion
"""
return {
@ -285,16 +292,18 @@ class MockVLLMServer(MockServerBase):
"object": "chat.completion",
"created": int(time.time()),
"model": request_data.get("model", "meta-llama/Llama-3.2-1B-Instruct"),
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "This is a test response from mock vLLM server.",
"tool_calls": None,
},
"logprobs": None,
"finish_reason": "stop",
}],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "This is a test response from mock vLLM server.",
"tool_calls": None,
},
"logprobs": None,
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 25,
"completion_tokens": 15,
@ -304,11 +313,11 @@ class MockVLLMServer(MockServerBase):
"system_fingerprint": None,
"service_tier": None,
}
def _create_text_completion_response(self, request_data: Dict) -> Dict:
def _create_text_completion_response(self, request_data: dict) -> dict:
"""
Create OpenAI Completion response.
Returns a valid response matching openai.types.Completion
"""
return {
@ -316,12 +325,14 @@ class MockVLLMServer(MockServerBase):
"object": "text_completion",
"created": int(time.time()),
"model": request_data.get("model", "meta-llama/Llama-3.2-1B-Instruct"),
"choices": [{
"text": "This is a test completion.",
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}],
"choices": [
{
"text": "This is a test completion.",
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 8,
@ -330,58 +341,57 @@ class MockVLLMServer(MockServerBase):
},
"system_fingerprint": None,
}
def _send_json_response(self, status_code: int, data: Dict):
def _send_json_response(self, status_code: int, data: dict):
"""Helper to send JSON response."""
self.send_response(status_code)
self.send_header('Content-Type', 'application/json')
self.send_header("Content-Type", "application/json")
self.end_headers()
self.wfile.write(json.dumps(data).encode())
return VLLMHandler
async def await_start(self):
"""
Start the vLLM server and wait until ready.
This method is async and can be awaited to ensure the server is ready.
"""
# Create handler and start the HTTP server
handler_class = self._create_handler_class()
self.server = http.server.HTTPServer(('localhost', self.port), handler_class)
self.server = http.server.HTTPServer(("localhost", self.port), handler_class)
self.server_thread = threading.Thread(target=self.server.serve_forever, daemon=True)
self.server_thread.start()
# Wait for server to be listening on the port
for _ in range(10):
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = sock.connect_ex(('localhost', self.port))
result = sock.connect_ex(("localhost", self.port))
sock.close()
if result == 0:
# Port is listening
return
except:
except Exception:
pass
await asyncio.sleep(0.1)
raise RuntimeError(f"vLLM server failed to start on port {self.port}")
def stop(self):
"""Stop the vLLM server."""
if self.server:
self.server.shutdown()
self.server.server_close()
def clear(self):
"""Clear request history."""
self.requests_received = []
def get_request_count(self) -> int:
"""Get number of requests received."""
return len(self.requests_received)
def get_all_requests(self) -> List[Dict]:
def get_all_requests(self) -> list[dict]:
"""Get all received requests with their bodies."""
return self.requests_received