litellm/litellm/testing.py

import litellm
import time
from concurrent.futures import ThreadPoolExecutor
import traceback

def testing_batch_completion(*args, **kwargs):
    try:
        batch_models = (
            args[0] if len(args) > 0 else kwargs.pop("models")
        )  ## expected input format: ["gpt-3.5-turbo", {"model": "qvv0xeq", "custom_llm_provider": "baseten"}, ...]
        batch_messages = args[1] if len(args) > 1 else kwargs.pop("messages")
        results = []
        completions = []
        exceptions = []
        times = []
        with ThreadPoolExecutor() as executor:
            for model in batch_models:
                kwargs_modified = dict(kwargs)
                args_modified = list(args)
                if len(args) > 0:
                    args_modified[0] = model["model"]
                else:
                    kwargs_modified["model"] = (
                        model["model"]
                        if isinstance(model, dict) and "model" in model
                        else model
                    )  # if model is a dictionary, read its "model" key; otherwise assume it's a plain model string
                    kwargs_modified["custom_llm_provider"] = (
                        model["custom_llm_provider"]
                        if isinstance(model, dict) and "custom_llm_provider" in model
                        else None
                    )
                    kwargs_modified["custom_api_base"] = (
                        model["custom_api_base"]
                        if isinstance(model, dict) and "custom_api_base" in model
                        else None
                    )
                for message_list in batch_messages:
                    if len(args) > 1:
                        args_modified[1] = message_list
                        future = executor.submit(
                            litellm.completion, *args_modified, **kwargs_modified
                        )
                    else:
                        kwargs_modified["messages"] = message_list
                        future = executor.submit(
                            litellm.completion, *args_modified, **kwargs_modified
                        )
                    completions.append((future, message_list))

            # Retrieve the results and calculate elapsed time for each completion call
            for completion in completions:
                future, message_list = completion
                start_time = time.time()
                try:
                    result = future.result()
                    end_time = time.time()
                    elapsed_time = end_time - start_time
                    result_dict = {
                        "status": "succeeded",
                        "response": result,
                        "prompt": message_list,
                        "response_time": elapsed_time,
                    }
                    results.append(result_dict)
                except Exception as e:
                    end_time = time.time()
                    elapsed_time = end_time - start_time
                    result_dict = {
                        "status": "failed",
                        "response": e,
                        "response_time": elapsed_time,
                    }
                    results.append(result_dict)
            return results
    except Exception:
        traceback.print_exc()
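

# --- Example usage (illustrative sketch, not executed on import) ---
# The model identifiers below are placeholders; any model/provider combination
# accepted by litellm.completion should work. Each entry in `messages` is a
# full chat-message list, and each result dict reports status and response time.
#
# results = testing_batch_completion(
#     models=["gpt-3.5-turbo", {"model": "command-nightly", "custom_llm_provider": "cohere"}],
#     messages=[[{"role": "user", "content": "Hey, how's it going?"}]],
# )
# for r in results:
#     print(r["status"], r["response_time"])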


def duration_test_model(original_function):
    def wrapper_function(*args, **kwargs):
        # Code to be executed before the original function
        duration = kwargs.pop("duration", None)
        interval = kwargs.pop("interval", None)
        results = []
        if duration and interval:
            start_time = time.time()
            end_time = start_time + duration  # keep re-running until the requested duration has elapsed
            while time.time() < end_time:
                result = original_function(*args, **kwargs)
                results.append(result)
                time.sleep(interval)
        else:
            result = original_function(*args, **kwargs)
            results = result
        return results

    # Return the wrapper function
    return wrapper_function
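

# --- Example usage (illustrative sketch) ---
# Any function wrapped with @duration_test_model accepts two extra keyword
# arguments, `duration` and `interval` (both in seconds); the wrapper consumes
# them and re-runs the wrapped function on that schedule, collecting results.
# The function below is hypothetical and only shows how the kwargs flow.
#
# @duration_test_model
# def ping():
#     return "pong"
#
# outputs = ping(duration=60, interval=10)  # re-run ping() every 10s for 60s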


@duration_test_model
def load_test_model(models: list, prompt: str = "", num_calls: int = 0):
    test_calls = 100  # default number of calls if num_calls isn't provided
    if num_calls:
        test_calls = num_calls
    input_prompt = prompt if prompt else "Hey, how's it going?"
    messages = [{"role": "user", "content": input_prompt}]
    full_message_list = [
        messages for _ in range(test_calls)
    ]  # repeat the prompt as many times as requested to load test the models
    start_time = time.time()
    try:
        results = testing_batch_completion(models=models, messages=full_message_list)
        end_time = time.time()
        response_time = end_time - start_time
        return {
            "total_response_time": response_time,
            "calls_made": test_calls,
            "prompt": input_prompt,
            "results": results,
        }
    except Exception as e:
        traceback.print_exc()
        end_time = time.time()
        response_time = end_time - start_time
        return {
            "total_response_time": response_time,
            "calls_made": test_calls,
            "prompt": input_prompt,
            "exception": e,
        }
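

# --- Example usage (illustrative sketch; model names are placeholders) ---
# Because load_test_model is decorated with @duration_test_model, it can run a
# single batch of `num_calls` requests, or repeat that batch on a schedule when
# `duration` and `interval` are supplied.
#
# # Single batch of 5 calls:
# report = load_test_model(models=["gpt-3.5-turbo"], prompt="Hello!", num_calls=5)
#
# # Repeat the same batch every 60 seconds for 10 minutes:
# reports = load_test_model(
#     models=["gpt-3.5-turbo"], prompt="Hello!", num_calls=5, duration=600, interval=60
# )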