Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)

Commit 211e1edfcb — add model load testing functionality
Parent: a2efd32f5c
8 changed files with 35 additions and 8 deletions
@@ -113,7 +113,7 @@ open_ai_embedding_models = [
 ]

 from .timeout import timeout
-from .utils import client, logging, exception_type, get_optional_params, modify_integration, token_counter, cost_per_token, completion_cost
+from .utils import client, logging, exception_type, get_optional_params, modify_integration, token_counter, cost_per_token, completion_cost, load_test_model
 from .main import * # Import all the symbols from main.py
 from .integrations import *
 from openai.error import AuthenticationError, InvalidRequestError, RateLimitError, ServiceUnavailableError, OpenAIError
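Because load_test_model is now re-exported here, it is reachable from the package top level. A minimal sketch, not part of the commit, assuming the package imports cleanly in your environment:

    import litellm
    from litellm import load_test_model   # same object the diff adds in utils.py
    assert litellm.load_test_model is load_test_model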
3 binary files changed (contents not shown)
@@ -59,7 +59,7 @@ def completion(
         # params to identify the model
         model=model, replicate=replicate, hugging_face=hugging_face, together_ai=together_ai
     )
-    if azure == True:
+    if azure == True or custom_llm_provider == "azure": # [TODO]: remove azure=True flag, move to 'custom_llm_provider' approach
         # azure configs
         openai.api_type = "azure"
         openai.api_base = litellm.api_base if litellm.api_base is not None else get_secret("AZURE_API_BASE")
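A sketch of what the new flag enables, not part of the commit: a caller can route to the Azure branch without passing azure=True, assuming completion() accepts the custom_llm_provider keyword referenced in this branch. AZURE_API_BASE is the secret read by get_secret above; the deployment name and any API-key setup are assumptions.

    import os, litellm
    os.environ["AZURE_API_BASE"] = "https://my-resource.openai.azure.com"  # read via get_secret above
    response = litellm.completion(
        model="my-gpt35-deployment",                       # hypothetical Azure deployment name
        messages=[{"role": "user", "content": "Hello"}],
        custom_llm_provider="azure",                       # takes the same path as azure=True
    )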
@@ -153,7 +153,7 @@ def completion(
         model_response["model"] = model
         model_response["usage"] = response["usage"]
         response = model_response
-    elif "replicate" in model or replicate == True:
+    elif "replicate" in model or replicate == True or custom_llm_provider == "replicate":
         # import replicate/if it fails then pip install replicate
         install_and_import("replicate")
         import replicate
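Illustrative sketch, not from the diff: this branch now triggers either when the model string contains "replicate" or when the override is passed explicitly. Model names below are made up, and a Replicate API token must be configured however the rest of the branch expects (not shown in this hunk).

    import litellm
    msgs = [{"role": "user", "content": "Hello"}]
    # either call reaches the Replicate branch
    litellm.completion(model="replicate/vicuna-13b", messages=msgs)                       # matched by substring
    litellm.completion(model="my-hosted-vicuna", messages=msgs, custom_llm_provider="replicate")  # matched by override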
@@ -256,7 +256,7 @@ def completion(
         }
         response = model_response

-    elif model in litellm.openrouter_models:
+    elif model in litellm.openrouter_models or custom_llm_provider == "openrouter":
         openai.api_type = "openai"
         # not sure if this will work after someone first uses another API
         openai.api_base = litellm.api_base if litellm.api_base is not None else "https://openrouter.ai/api/v1"
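A sketch for the OpenRouter path, not part of the commit: with no litellm.api_base override, the branch falls back to the OpenRouter URL shown above. The model id is an assumption, and OpenRouter API-key setup is omitted here.

    import litellm
    litellm.api_base = None                      # fall back to https://openrouter.ai/api/v1
    response = litellm.completion(
        model="some-openrouter-model",           # hypothetical; real ids live in litellm.openrouter_models
        messages=[{"role": "user", "content": "Hi"}],
        custom_llm_provider="openrouter",
    )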
@@ -338,7 +338,7 @@ def completion(
             "total_tokens": prompt_tokens + completion_tokens
         }
         response = model_response
-    elif hugging_face == True:
+    elif hugging_face == True or custom_llm_provider == "huggingface":
         import requests
         API_URL = f"https://api-inference.huggingface.co/models/{model}"
         HF_TOKEN = get_secret("HF_TOKEN")
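Sketch for the Hugging Face Inference API path, not part of the commit: the branch reads HF_TOKEN via get_secret and builds the request URL from the model name. The model id below is an assumption.

    import os, litellm
    os.environ["HF_TOKEN"] = "hf_xxx"            # the secret read by get_secret("HF_TOKEN") above
    response = litellm.completion(
        model="bigscience/bloom",                # becomes api-inference.huggingface.co/models/bigscience/bloom
        messages=[{"role": "user", "content": "Hi"}],
        custom_llm_provider="huggingface",
    )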
@@ -364,7 +364,7 @@ def completion(
             "total_tokens": prompt_tokens + completion_tokens
         }
         response = model_response
-    elif together_ai == True:
+    elif together_ai == True or custom_llm_provider == "together_ai":
         import requests
         TOGETHER_AI_TOKEN = get_secret("TOGETHER_AI_TOKEN")
         headers = {"Authorization": f"Bearer {TOGETHER_AI_TOKEN}"}
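Same idea for Together AI, again an illustrative sketch: the branch reads TOGETHER_AI_TOKEN and sends it as the Bearer token in the headers above. The model id is an assumption.

    import os, litellm
    os.environ["TOGETHER_AI_TOKEN"] = "xxx"      # used for the Authorization header above
    response = litellm.completion(
        model="togethercomputer/llama-2-7b-chat",  # hypothetical Together AI model id
        messages=[{"role": "user", "content": "Hi"}],
        custom_llm_provider="together_ai",
    )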
@@ -430,7 +430,7 @@ def completion(
             ## LOGGING
             logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
             args = locals()
-            raise ValueError(f"No valid completion model args passed in - {args}")
+            raise ValueError(f"Invalid completion model args passed in. Check your input - {args}")
         return response
     except Exception as e:
         ## LOGGING
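A rough sketch of the reworded failure path, assuming the model string matches none of the provider branches above:

    import litellm
    try:
        litellm.completion(model="definitely-not-a-model", messages=[{"role": "user", "content": "Hi"}])
    except ValueError as err:
        print(err)   # Invalid completion model args passed in. Check your input - {...}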
litellm/tests/test_load_test_model.py (new file, 8 additions)
@@ -0,0 +1,8 @@
+import sys, os
+import traceback
+sys.path.insert(0, os.path.abspath('../..'))  # Adds the parent directory to the system path
+import litellm
+from litellm import load_test_model
+
+result = load_test_model(model="gpt-3.5-turbo", num_calls=5)
+print(result)
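The new test is a plain script rather than a pytest case: sys.path.insert points at the repo root so the local package is importable, and the call goes through litellm.batch_completion, so it needs whatever credentials completion() normally uses for gpt-3.5-turbo (an assumption; the script itself sets none). The parameters it leaves at their defaults can also be overridden, e.g.:

    # illustrative variation on the script above, overriding the default prompt as well
    result = load_test_model(model="gpt-3.5-turbo", prompt="Reply with one word", num_calls=10)
    print(result)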
@@ -302,6 +302,25 @@ def get_optional_params(
     return optional_params

+def load_test_model(model: str, custom_llm_provider: str = None, prompt: str = None, num_calls: int = None):
+    test_prompt = "Hey, how's it going"
+    test_calls = 100
+    if prompt:
+        test_prompt = prompt
+    if num_calls:
+        test_calls = num_calls
+    messages = [[{"role": "user", "content": test_prompt}] for _ in range(test_calls)]
+    start_time = time.time()
+    try:
+        litellm.batch_completion(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
+        end_time = time.time()
+        response_time = end_time - start_time
+        return {"total_response_time": response_time, "calls_made": 100, "status": "success", "exception": None}
+    except Exception as e:
+        end_time = time.time()
+        response_time = end_time - start_time
+        return {"total_response_time": response_time, "calls_made": 100, "status": "failed", "exception": e}
+
 def set_callbacks(callback_list):
     global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient
     try:
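A usage sketch for the helper as defined above; the comment describes the shape of the returned dict rather than real measurements. Note that in this version "calls_made" is hardcoded to 100 even when num_calls is smaller, and total_response_time is wall-clock time for the whole batch, not per-call latency.

    report = load_test_model(model="gpt-3.5-turbo", num_calls=10)
    # report is a dict shaped like:
    #   {"total_response_time": <seconds for the whole batch>,
    #    "calls_made": 100,        # hardcoded in this commit, regardless of num_calls
    #    "status": "success" or "failed",
    #    "exception": None or the exception that was raised}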
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.381"
+version = "0.1.382"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"