mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 04:00:42 +00:00
fix(perf): make batches tests finish 30x faster
This commit is contained in:
parent
cd152f4240
commit
b8511abc29
1 changed files with 20 additions and 6 deletions
|
|
@ -70,10 +70,15 @@ class BatchHelper:
|
|||
):
|
||||
"""Wait for a batch to reach a terminal status.
|
||||
|
||||
Uses exponential backoff polling strategy for efficient waiting:
|
||||
- Starts with short intervals (0.1s) for fast batches (e.g., replay mode)
|
||||
- Doubles interval each iteration up to a maximum
|
||||
- Adapts automatically to both fast and slow batch processing
|
||||
|
||||
Args:
|
||||
batch_id: The batch ID to monitor
|
||||
max_wait_time: Maximum time to wait in seconds (default: 60 seconds)
|
||||
sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s)
|
||||
sleep_interval: If provided, uses fixed interval instead of exponential backoff
|
||||
expected_statuses: Set of expected terminal statuses (default: {"completed"})
|
||||
timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip)
|
||||
|
||||
|
|
@ -84,10 +89,6 @@ class BatchHelper:
|
|||
pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail"
|
||||
pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status
|
||||
"""
|
||||
if sleep_interval is None:
|
||||
# Default to 1/10th of max_wait_time, with min 1s and max 15s
|
||||
sleep_interval = max(1, min(15, max_wait_time // 10))
|
||||
|
||||
if expected_statuses is None:
|
||||
expected_statuses = {"completed"}
|
||||
|
||||
|
|
@ -95,6 +96,15 @@ class BatchHelper:
|
|||
unexpected_statuses = terminal_statuses - expected_statuses
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Use exponential backoff if no explicit sleep_interval provided
|
||||
if sleep_interval is None:
|
||||
current_interval = 0.1 # Start with 100ms
|
||||
max_interval = 10.0 # Cap at 10 seconds
|
||||
else:
|
||||
current_interval = sleep_interval
|
||||
max_interval = sleep_interval
|
||||
|
||||
while time.time() - start_time < max_wait_time:
|
||||
current_batch = self.client.batches.retrieve(batch_id)
|
||||
|
||||
|
|
@ -107,7 +117,11 @@ class BatchHelper:
|
|||
else:
|
||||
pytest.fail(error_msg)
|
||||
|
||||
time.sleep(sleep_interval)
|
||||
time.sleep(current_interval)
|
||||
|
||||
# Exponential backoff: double the interval each time, up to max
|
||||
if sleep_interval is None:
|
||||
current_interval = min(current_interval * 2, max_interval)
|
||||
|
||||
timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds"
|
||||
if timeout_action == "skip":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue