From 6812a194dbd4f45165267d59b940933c8ed2dd7c Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 16 Aug 2023 12:28:47 -0700
Subject: [PATCH] update model load testing

---
 litellm/testing.py | 18 +++++++++---------
 pyproject.toml     |  2 +-
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/litellm/testing.py b/litellm/testing.py
index b3d786fa80..ebec952f2a 100644
--- a/litellm/testing.py
+++ b/litellm/testing.py
@@ -19,7 +19,8 @@ def testing_batch_completion(*args, **kwargs):
                     args_modified[0] = model["model"]
                 else:
                     kwargs_modified["model"] = model["model"] if isinstance(model, dict) and "model" in model else model # if model is a dictionary get it's value else assume it's a string
-                    kwargs_modified["custom_llm_provider"] = model["custom_llm_provider"] if isinstance(model, dict) and "model" in model else None
+                    kwargs_modified["custom_llm_provider"] = model["custom_llm_provider"] if isinstance(model, dict) and "custom_llm_provider" in model else None
+                    kwargs_modified["custom_api_base"] = model["custom_api_base"] if isinstance(model, dict) and "custom_api_base" in model else None
                 for message_list in batch_messages:
                     if len(args) > 1:
                         args_modified[1] = message_list
@@ -70,22 +71,21 @@ def duration_test_model(original_function):
     return wrapper_function
 
 @duration_test_model
-def load_test_model(model: str, custom_llm_provider: str = None, custom_api_base: str = None, prompt: str = None, num_calls: int = None, request_timeout: int = None):
-    test_prompt = "Hey, how's it going"
+def load_test_model(models: list, prompt: str = None, num_calls: int = None):
     test_calls = 100
-    if prompt:
-        test_prompt = prompt
     if num_calls:
         test_calls = num_calls
-    messages = [[{"role": "user", "content": test_prompt}] for _ in range(test_calls)]
+    input_prompt = prompt if prompt else "Hey, how's it going?"
+    messages = [{"role": "user", "content": prompt}] if prompt else [{"role": "user", "content": input_prompt}]
+    full_message_list = [messages for _ in range(test_calls)] # call it as many times as set by user to load test models
    start_time = time.time()
     try:
-        results = testing_batch_completion(models=[model], messages=messages, custom_llm_provider=custom_llm_provider, custom_api_base = custom_api_base, force_timeout=request_timeout)
+        results = testing_batch_completion(models=models, messages=full_message_list)
         end_time = time.time()
         response_time = end_time - start_time
-        return {"total_response_time": response_time, "calls_made": test_calls, "prompt": test_prompt, "results": results}
+        return {"total_response_time": response_time, "calls_made": test_calls, "prompt": input_prompt, "results": results}
     except Exception as e:
         traceback.print_exc()
         end_time = time.time()
         response_time = end_time - start_time
-        return {"total_response_time": response_time, "calls_made": test_calls, "prompt": test_prompt, "exception": e}
\ No newline at end of file
+        return {"total_response_time": response_time, "calls_made": test_calls, "prompt": input_prompt, "exception": e}
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 0545d8dc88..3578354929 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.403"
+version = "0.1.404"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
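
Usage note (illustrative, not part of the patch): after this change, load_test_model takes a list of models instead of a single model string plus custom_llm_provider / custom_api_base / request_timeout arguments; per-model overrides can instead be supplied as dicts, which testing_batch_completion unpacks. A minimal sketch, assuming the model names below are just placeholders and that the duration_test_model wrapper passes these keyword arguments through unchanged:

    from litellm.testing import load_test_model

    # models can be plain strings or dicts carrying per-model overrides
    # (the dict keys "model", "custom_llm_provider", "custom_api_base" are the ones
    # testing_batch_completion looks for in this patch)
    models = [
        "gpt-3.5-turbo",  # hypothetical example model name
        {"model": "claude-instant-1", "custom_llm_provider": "anthropic"},
    ]

    report = load_test_model(models=models, prompt="Hey, how's it going?", num_calls=10)
    # the returned dict carries total_response_time, calls_made, prompt, and results
    # (or exception if the batch call failed)
    print(report["total_response_time"], report["calls_made"])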