mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-18 15:27:16 +00:00
fix(perf): make batches tests finish 30x faster (#3834)
In replay mode, inference is instantenous. We don't need to wait 15 seconds for the batch to be done. Fixing polling to do exp backoff makes things work super fast.
This commit is contained in:
parent
cd152f4240
commit
4c9d944380
1 changed files with 20 additions and 6 deletions
|
@ -70,10 +70,15 @@ class BatchHelper:
|
|||
):
|
||||
"""Wait for a batch to reach a terminal status.
|
||||
|
||||
Uses exponential backoff polling strategy for efficient waiting:
|
||||
- Starts with short intervals (0.1s) for fast batches (e.g., replay mode)
|
||||
- Doubles interval each iteration up to a maximum
|
||||
- Adapts automatically to both fast and slow batch processing
|
||||
|
||||
Args:
|
||||
batch_id: The batch ID to monitor
|
||||
max_wait_time: Maximum time to wait in seconds (default: 60 seconds)
|
||||
sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s)
|
||||
sleep_interval: If provided, uses fixed interval instead of exponential backoff
|
||||
expected_statuses: Set of expected terminal statuses (default: {"completed"})
|
||||
timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip)
|
||||
|
||||
|
@ -84,10 +89,6 @@ class BatchHelper:
|
|||
pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail"
|
||||
pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status
|
||||
"""
|
||||
if sleep_interval is None:
|
||||
# Default to 1/10th of max_wait_time, with min 1s and max 15s
|
||||
sleep_interval = max(1, min(15, max_wait_time // 10))
|
||||
|
||||
if expected_statuses is None:
|
||||
expected_statuses = {"completed"}
|
||||
|
||||
|
@ -95,6 +96,15 @@ class BatchHelper:
|
|||
unexpected_statuses = terminal_statuses - expected_statuses
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Use exponential backoff if no explicit sleep_interval provided
|
||||
if sleep_interval is None:
|
||||
current_interval = 0.1 # Start with 100ms
|
||||
max_interval = 10.0 # Cap at 10 seconds
|
||||
else:
|
||||
current_interval = sleep_interval
|
||||
max_interval = sleep_interval
|
||||
|
||||
while time.time() - start_time < max_wait_time:
|
||||
current_batch = self.client.batches.retrieve(batch_id)
|
||||
|
||||
|
@ -107,7 +117,11 @@ class BatchHelper:
|
|||
else:
|
||||
pytest.fail(error_msg)
|
||||
|
||||
time.sleep(sleep_interval)
|
||||
time.sleep(current_interval)
|
||||
|
||||
# Exponential backoff: double the interval each time, up to max
|
||||
if sleep_interval is None:
|
||||
current_interval = min(current_interval * 2, max_interval)
|
||||
|
||||
timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds"
|
||||
if timeout_action == "skip":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue