add model load testing functionality

Krrish Dholakia 2023-08-11 17:59:51 -07:00
parent a2efd32f5c
commit 211e1edfcb
8 changed files with 35 additions and 8 deletions

View file

@@ -113,7 +113,7 @@ open_ai_embedding_models = [
]
from .timeout import timeout
from .utils import client, logging, exception_type, get_optional_params, modify_integration, token_counter, cost_per_token, completion_cost
from .utils import client, logging, exception_type, get_optional_params, modify_integration, token_counter, cost_per_token, completion_cost, load_test_model
from .main import * # Import all the symbols from main.py
from .integrations import *
from openai.error import AuthenticationError, InvalidRequestError, RateLimitError, ServiceUnavailableError, OpenAIError

View file

@@ -59,7 +59,7 @@ def completion(
            # params to identify the model
            model=model, replicate=replicate, hugging_face=hugging_face, together_ai=together_ai
        )
        if azure == True:
        if azure == True or custom_llm_provider == "azure": # [TODO]: remove azure=True flag, move to 'custom_llm_provider' approach
            # azure configs
            openai.api_type = "azure"
            openai.api_base = litellm.api_base if litellm.api_base is not None else get_secret("AZURE_API_BASE")
@@ -153,7 +153,7 @@ def completion(
            model_response["model"] = model
            model_response["usage"] = response["usage"]
            response = model_response
        elif "replicate" in model or replicate == True:
        elif "replicate" in model or replicate == True or custom_llm_provider == "replicate":
            # import replicate/if it fails then pip install replicate
            install_and_import("replicate")
            import replicate
@@ -256,7 +256,7 @@ def completion(
            }
            response = model_response
        elif model in litellm.openrouter_models:
        elif model in litellm.openrouter_models or custom_llm_provider == "openrouter":
            openai.api_type = "openai"
            # not sure if this will work after someone first uses another API
            openai.api_base = litellm.api_base if litellm.api_base is not None else "https://openrouter.ai/api/v1"
@@ -338,7 +338,7 @@ def completion(
                "total_tokens": prompt_tokens + completion_tokens
            }
            response = model_response
        elif hugging_face == True:
        elif hugging_face == True or custom_llm_provider == "huggingface":
            import requests
            API_URL = f"https://api-inference.huggingface.co/models/{model}"
            HF_TOKEN = get_secret("HF_TOKEN")
@@ -364,7 +364,7 @@ def completion(
                "total_tokens": prompt_tokens + completion_tokens
            }
            response = model_response
        elif together_ai == True:
        elif together_ai == True or custom_llm_provider == "together_ai":
            import requests
            TOGETHER_AI_TOKEN = get_secret("TOGETHER_AI_TOKEN")
            headers = {"Authorization": f"Bearer {TOGETHER_AI_TOKEN}"}
@@ -430,7 +430,7 @@ def completion(
            ## LOGGING
            logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
            args = locals()
            raise ValueError(f"No valid completion model args passed in - {args}")
            raise ValueError(f"Invalid completion model args passed in. Check your input - {args}")
        return response
    except Exception as e:
        ## LOGGING
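A minimal sketch of how the new custom_llm_provider routing could be called, assuming completion() accepts it as a keyword argument; the model name below is a placeholder, not part of the commit:

import litellm

# Hypothetical call: pick the provider explicitly via custom_llm_provider
# instead of the together_ai=True boolean flag (model name is illustrative).
response = litellm.completion(
    model="togethercomputer/llama-2-7b-chat",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    custom_llm_provider="together_ai",
)
print(response)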

View file

@@ -0,0 +1,8 @@
import sys, os
import traceback
sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
import litellm
from litellm import load_test_model
result = load_test_model(model="gpt-3.5-turbo", num_calls=5)
print(result)
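The result printed here is the dict built by load_test_model in utils.py (see below); a small sketch of how it could be inspected, not part of the commit:

# Keys match the dict returned by load_test_model.
if result["status"] == "success":
    print(f"made {result['calls_made']} calls in {result['total_response_time']:.2f}s")
else:
    print(f"load test failed: {result['exception']}")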

View file

@@ -302,6 +302,25 @@ def get_optional_params(
        return optional_params
    return optional_params

def load_test_model(model: str, custom_llm_provider: str = None, prompt: str = None, num_calls: int = None):
    # defaults used when no prompt / num_calls are supplied
    test_prompt = "Hey, how's it going"
    test_calls = 100
    if prompt:
        test_prompt = prompt
    if num_calls:
        test_calls = num_calls
    messages = [[{"role": "user", "content": test_prompt}] for _ in range(test_calls)]
    start_time = time.time()
    try:
        litellm.batch_completion(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
        end_time = time.time()
        response_time = end_time - start_time
        return {"total_response_time": response_time, "calls_made": test_calls, "status": "success", "exception": None}
    except Exception as e:
        end_time = time.time()
        response_time = end_time - start_time
        return {"total_response_time": response_time, "calls_made": test_calls, "status": "failed", "exception": e}

def set_callbacks(callback_list):
    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient
    try:
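For illustration, load_test_model can also be driven with a custom prompt, call count, and provider; a sketch only, assuming the provider string matches the custom_llm_provider values handled in main.py (the model name is a placeholder):

import litellm
from litellm import load_test_model

# Hypothetical: 10 calls with a custom prompt, routed via custom_llm_provider.
result = load_test_model(
    model="bigcode/starcoder",          # placeholder model name
    custom_llm_provider="huggingface",
    prompt="Write a haiku about load testing.",
    num_calls=10,
)
print(result["status"], result["total_response_time"])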

View file

@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "0.1.381"
version = "0.1.382"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"