fix(perf): make batches tests finish 30x faster (#3834)

In replay mode, inference is instantenous. We don't need to wait 15
seconds for the batch to be done. Fixing polling to do exp backoff makes
things work super fast.
This commit is contained in:
Ashwin Bharambe 2025-10-17 00:16:44 -07:00 committed by GitHub
parent cd152f4240
commit 4c9d944380
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -70,10 +70,15 @@ class BatchHelper:
):
"""Wait for a batch to reach a terminal status.
Uses exponential backoff polling strategy for efficient waiting:
- Starts with short intervals (0.1s) for fast batches (e.g., replay mode)
- Doubles interval each iteration up to a maximum
- Adapts automatically to both fast and slow batch processing
Args:
batch_id: The batch ID to monitor
max_wait_time: Maximum time to wait in seconds (default: 60 seconds)
sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s)
sleep_interval: If provided, uses fixed interval instead of exponential backoff
expected_statuses: Set of expected terminal statuses (default: {"completed"})
timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip)
@ -84,10 +89,6 @@ class BatchHelper:
pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail"
pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status
"""
if sleep_interval is None:
# Default to 1/10th of max_wait_time, with min 1s and max 15s
sleep_interval = max(1, min(15, max_wait_time // 10))
if expected_statuses is None:
expected_statuses = {"completed"}
@ -95,6 +96,15 @@ class BatchHelper:
unexpected_statuses = terminal_statuses - expected_statuses
start_time = time.time()
# Use exponential backoff if no explicit sleep_interval provided
if sleep_interval is None:
current_interval = 0.1 # Start with 100ms
max_interval = 10.0 # Cap at 10 seconds
else:
current_interval = sleep_interval
max_interval = sleep_interval
while time.time() - start_time < max_wait_time:
current_batch = self.client.batches.retrieve(batch_id)
@ -107,7 +117,11 @@ class BatchHelper:
else:
pytest.fail(error_msg)
time.sleep(sleep_interval)
time.sleep(current_interval)
# Exponential backoff: double the interval each time, up to max
if sleep_interval is None:
current_interval = min(current_interval * 2, max_interval)
timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds"
if timeout_action == "skip":