mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-24 00:47:00 +00:00
267 lines
10 KiB
Python
267 lines
10 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
"""
|
|
Simple benchmark script for Llama Stack with OpenAI API compatibility.
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import os
|
|
import random
|
|
import statistics
|
|
import time
|
|
from typing import Tuple
|
|
import aiohttp
|
|
|
|
|
|
class BenchmarkStats:
|
|
def __init__(self):
|
|
self.response_times = []
|
|
self.ttft_times = []
|
|
self.chunks_received = []
|
|
self.errors = []
|
|
self.success_count = 0
|
|
self.total_requests = 0
|
|
self.concurrent_users = 0
|
|
self.start_time = None
|
|
self.end_time = None
|
|
self._lock = asyncio.Lock()
|
|
|
|
async def add_result(self, response_time: float, chunks: int, ttft: float = None, error: str = None):
|
|
async with self._lock:
|
|
self.total_requests += 1
|
|
if error:
|
|
self.errors.append(error)
|
|
else:
|
|
self.success_count += 1
|
|
self.response_times.append(response_time)
|
|
self.chunks_received.append(chunks)
|
|
if ttft is not None:
|
|
self.ttft_times.append(ttft)
|
|
|
|
def print_summary(self):
|
|
if not self.response_times:
|
|
print("No successful requests to report")
|
|
if self.errors:
|
|
print(f"Total errors: {len(self.errors)}")
|
|
print("First 5 errors:")
|
|
for error in self.errors[:5]:
|
|
print(f" {error}")
|
|
return
|
|
|
|
total_time = self.end_time - self.start_time
|
|
success_rate = (self.success_count / self.total_requests) * 100
|
|
|
|
print(f"\n{'='*60}")
|
|
print(f"BENCHMARK RESULTS")
|
|
print(f"{'='*60}")
|
|
print(f"Total time: {total_time:.2f}s")
|
|
print(f"Concurrent users: {self.concurrent_users}")
|
|
print(f"Total requests: {self.total_requests}")
|
|
print(f"Successful requests: {self.success_count}")
|
|
print(f"Failed requests: {len(self.errors)}")
|
|
print(f"Success rate: {success_rate:.1f}%")
|
|
print(f"Requests per second: {self.success_count / total_time:.2f}")
|
|
|
|
print(f"\nResponse Time Statistics:")
|
|
print(f" Mean: {statistics.mean(self.response_times):.3f}s")
|
|
print(f" Median: {statistics.median(self.response_times):.3f}s")
|
|
print(f" Min: {min(self.response_times):.3f}s")
|
|
print(f" Max: {max(self.response_times):.3f}s")
|
|
|
|
if len(self.response_times) > 1:
|
|
print(f" Std Dev: {statistics.stdev(self.response_times):.3f}s")
|
|
|
|
percentiles = [50, 90, 95, 99]
|
|
sorted_times = sorted(self.response_times)
|
|
print(f"\nPercentiles:")
|
|
for p in percentiles:
|
|
idx = int(len(sorted_times) * p / 100) - 1
|
|
idx = max(0, min(idx, len(sorted_times) - 1))
|
|
print(f" P{p}: {sorted_times[idx]:.3f}s")
|
|
|
|
if self.ttft_times:
|
|
print(f"\nTime to First Token (TTFT) Statistics:")
|
|
print(f" Mean: {statistics.mean(self.ttft_times):.3f}s")
|
|
print(f" Median: {statistics.median(self.ttft_times):.3f}s")
|
|
print(f" Min: {min(self.ttft_times):.3f}s")
|
|
print(f" Max: {max(self.ttft_times):.3f}s")
|
|
|
|
if len(self.ttft_times) > 1:
|
|
print(f" Std Dev: {statistics.stdev(self.ttft_times):.3f}s")
|
|
|
|
sorted_ttft = sorted(self.ttft_times)
|
|
print(f"\nTTFT Percentiles:")
|
|
for p in percentiles:
|
|
idx = int(len(sorted_ttft) * p / 100) - 1
|
|
idx = max(0, min(idx, len(sorted_ttft) - 1))
|
|
print(f" P{p}: {sorted_ttft[idx]:.3f}s")
|
|
|
|
if self.chunks_received:
|
|
print(f"\nStreaming Statistics:")
|
|
print(f" Mean chunks per response: {statistics.mean(self.chunks_received):.1f}")
|
|
print(f" Total chunks received: {sum(self.chunks_received)}")
|
|
|
|
if self.errors:
|
|
print(f"\nErrors (showing first 5):")
|
|
for error in self.errors[:5]:
|
|
print(f" {error}")
|
|
|
|
|
|
class LlamaStackBenchmark:
|
|
def __init__(self, base_url: str, model_id: str):
|
|
self.base_url = base_url.rstrip('/')
|
|
self.model_id = model_id
|
|
self.headers = {"Content-Type": "application/json"}
|
|
self.test_messages = [
|
|
[{"role": "user", "content": "Hi"}],
|
|
[{"role": "user", "content": "What is the capital of France?"}],
|
|
[{"role": "user", "content": "Explain quantum physics in simple terms."}],
|
|
[{"role": "user", "content": "Write a short story about a robot learning to paint."}],
|
|
[
|
|
{"role": "user", "content": "What is machine learning?"},
|
|
{"role": "assistant", "content": "Machine learning is a subset of AI..."},
|
|
{"role": "user", "content": "Can you give me a practical example?"}
|
|
]
|
|
]
|
|
|
|
|
|
async def make_async_streaming_request(self) -> Tuple[float, int, float | None, str | None]:
|
|
"""Make a single async streaming chat completion request."""
|
|
messages = random.choice(self.test_messages)
|
|
payload = {
|
|
"model": self.model_id,
|
|
"messages": messages,
|
|
"stream": True,
|
|
"max_tokens": 100
|
|
}
|
|
|
|
start_time = time.time()
|
|
chunks_received = 0
|
|
ttft = None
|
|
error = None
|
|
|
|
session = aiohttp.ClientSession()
|
|
|
|
try:
|
|
async with session.post(
|
|
f"{self.base_url}/chat/completions",
|
|
headers=self.headers,
|
|
json=payload,
|
|
timeout=aiohttp.ClientTimeout(total=30)
|
|
) as response:
|
|
if response.status == 200:
|
|
async for line in response.content:
|
|
if line:
|
|
line_str = line.decode('utf-8').strip()
|
|
if line_str.startswith('data: '):
|
|
chunks_received += 1
|
|
if ttft is None:
|
|
ttft = time.time() - start_time
|
|
if line_str == 'data: [DONE]':
|
|
break
|
|
|
|
if chunks_received == 0:
|
|
error = "No streaming chunks received"
|
|
else:
|
|
text = await response.text()
|
|
error = f"HTTP {response.status}: {text[:100]}"
|
|
|
|
except Exception as e:
|
|
error = f"Request error: {str(e)}"
|
|
finally:
|
|
await session.close()
|
|
|
|
response_time = time.time() - start_time
|
|
return response_time, chunks_received, ttft, error
|
|
|
|
|
|
async def run_benchmark(self, duration: int, concurrent_users: int) -> BenchmarkStats:
|
|
"""Run benchmark using async requests for specified duration."""
|
|
stats = BenchmarkStats()
|
|
stats.concurrent_users = concurrent_users
|
|
stats.start_time = time.time()
|
|
|
|
print(f"Starting benchmark: {duration}s duration, {concurrent_users} concurrent users")
|
|
print(f"Target URL: {self.base_url}/chat/completions")
|
|
print(f"Model: {self.model_id}")
|
|
|
|
connector = aiohttp.TCPConnector(limit=concurrent_users)
|
|
async with aiohttp.ClientSession(connector=connector) as session:
|
|
|
|
async def worker(worker_id: int):
|
|
"""Worker that sends requests sequentially until canceled."""
|
|
request_count = 0
|
|
while True:
|
|
try:
|
|
response_time, chunks, ttft, error = await self.make_async_streaming_request()
|
|
await stats.add_result(response_time, chunks, ttft, error)
|
|
request_count += 1
|
|
|
|
except asyncio.CancelledError:
|
|
break
|
|
except Exception as e:
|
|
await stats.add_result(0, 0, None, f"Worker {worker_id} error: {str(e)}")
|
|
|
|
# Progress reporting task
|
|
async def progress_reporter():
|
|
last_report_time = time.time()
|
|
while True:
|
|
try:
|
|
await asyncio.sleep(1) # Report every second
|
|
if time.time() >= last_report_time + 10: # Report every 10 seconds
|
|
elapsed = time.time() - stats.start_time
|
|
print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s")
|
|
last_report_time = time.time()
|
|
except asyncio.CancelledError:
|
|
break
|
|
|
|
# Spawn concurrent workers
|
|
tasks = [asyncio.create_task(worker(i)) for i in range(concurrent_users)]
|
|
progress_task = asyncio.create_task(progress_reporter())
|
|
tasks.append(progress_task)
|
|
|
|
# Wait for duration then cancel all tasks
|
|
await asyncio.sleep(duration)
|
|
|
|
for task in tasks:
|
|
task.cancel()
|
|
|
|
# Wait for all tasks to complete
|
|
await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
stats.end_time = time.time()
|
|
return stats
|
|
|
|
|
|
def main():
    """Parse command-line options, run one benchmark, and print its summary.

    Defaults for ``--base-url`` and ``--model`` come from the
    ``BENCHMARK_BASE_URL`` and ``INFERENCE_MODEL`` environment variables when
    set. A keyboard interrupt or any other exception is reported on stdout
    instead of propagating.
    """
    default_base_url = os.getenv("BENCHMARK_BASE_URL", "http://localhost:8000/v1/openai/v1")
    default_model = os.getenv("INFERENCE_MODEL", "test-model")

    cli = argparse.ArgumentParser(description="Llama Stack Benchmark Tool")
    cli.add_argument(
        "--base-url",
        default=default_base_url,
        help="Base URL for the API (default: http://localhost:8000/v1/openai/v1)",
    )
    cli.add_argument(
        "--model",
        default=default_model,
        help="Model ID to use for requests",
    )
    cli.add_argument(
        "--duration",
        type=int,
        default=60,
        help="Duration in seconds to run benchmark (default: 60)",
    )
    cli.add_argument(
        "--concurrent",
        type=int,
        default=10,
        help="Number of concurrent users (default: 10)",
    )
    options = cli.parse_args()

    runner = LlamaStackBenchmark(options.base_url, options.model)

    try:
        results = asyncio.run(runner.run_benchmark(options.duration, options.concurrent))
        results.print_summary()
    except KeyboardInterrupt:
        print("\nBenchmark interrupted by user")
    except Exception as exc:
        print(f"Benchmark failed: {exc}")


# Run the CLI only when the file is executed as a script.
if __name__ == "__main__":
    main()
|