improve error message returned if model not passed in

Krrish Dholakia 2023-09-09 11:18:10 -07:00
parent 63c10c2695
commit a9cab12a47
6 changed files with 30 additions and 31 deletions

litellm/tests/test_bad_params.py

@@ -4,32 +4,24 @@
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
# Get the current directory of the script
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get the parent directory by joining the current directory with '..'
parent_dir = os.path.join(current_dir, "../..")
# Add the parent directory to the system path
sys.path.append(parent_dir)
import pytest
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm import embedding, completion
litellm.set_verbose = True
litellm.success_callback = ["posthog"]
litellm.failure_callback = ["slack", "sentry", "posthog"]
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
model_val = None
def test_completion_with_no_model():
    # test on empty
    with pytest.raises(ValueError):
        response = completion(messages=messages)


def test_completion_with_empty_model():
    # test on empty
@@ -40,14 +32,14 @@ def test_completion_with_empty_model():
        pass
    # bad key
    temp_key = os.environ.get("OPENAI_API_KEY")
    os.environ["OPENAI_API_KEY"] = "bad-key"
    # test on openai completion call
    try:
        response = completion(model="gpt-3.5-turbo", messages=messages)
        print(f"response: {response}")
    except:
        print(f"error occurred: {traceback.format_exc()}")
        pass
    os.environ["OPENAI_API_KEY"] = str(temp_key)  # this passes linting#5
    # # bad key
    # temp_key = os.environ.get("OPENAI_API_KEY")
    # os.environ["OPENAI_API_KEY"] = "bad-key"
    # # test on openai completion call
    # try:
    #     response = completion(model="gpt-3.5-turbo", messages=messages)
    #     print(f"response: {response}")
    # except:
    #     print(f"error occurred: {traceback.format_exc()}")
    #     pass
    # os.environ["OPENAI_API_KEY"] = str(temp_key)  # this passes linting#5
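With the test file updated, the new guard can be exercised on its own. A minimal sketch, assuming pytest is installed and the test path above is correct (it is inferred from litellm's test layout, not shown in this view):

import pytest

# Select only the two model-validation tests; "-q" keeps output terse.
pytest.main(["-q", "litellm/tests/test_bad_params.py", "-k", "no_model or empty_model"])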

litellm/utils.py

@@ -434,6 +434,12 @@ def exception_logging(
def client(original_function):
    global liteDebuggerClient, get_all_keys

    def check_args(*args, **kwargs):
        try:
            model = args[0] if len(args) > 0 else kwargs["model"]
        except:
            raise ValueError("model param not passed in.")

    def function_setup(
        start_time, *args, **kwargs
    ):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -512,9 +518,10 @@ def client(original_function):
        result = None
        litellm_call_id = str(uuid.uuid4())
        kwargs["litellm_call_id"] = litellm_call_id
        logging_obj = function_setup(start_time, *args, **kwargs)
        kwargs["litellm_logging_obj"] = logging_obj
        check_args(*args, **kwargs)
        try:
            logging_obj = function_setup(start_time, *args, **kwargs)
            kwargs["litellm_logging_obj"] = logging_obj
            # [OPTIONAL] CHECK CACHE
            # remove this after deprecating litellm.caching
            if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
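From the caller's side, the effect of the new check_args guard is that completion() now fails fast with a descriptive ValueError when no model is supplied, which is exactly what test_completion_with_no_model above asserts. A minimal sketch of that behavior:

from litellm import completion

messages = [{"content": "Hello, how are you?", "role": "user"}]

try:
    # No positional model and no `model` kwarg, so check_args raises
    # before any logging setup or provider call is attempted.
    completion(messages=messages)
except ValueError as e:
    print(e)  # model param not passed in.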

pyproject.toml

@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "0.1.574"
version = "0.1.575"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"