test(test_proxy.py): add tests for debugging

This commit is contained in:
Krrish Dholakia 2023-10-17 16:47:48 -07:00
parent 22937b3b16
commit b67fe857dd
2 changed files with 28 additions and 16 deletions

View file

@ -278,6 +278,8 @@ def initialize(model, alias, api_base, debug, temperature, max_tokens, max_budge
if max_budget: # litellm-specific param if max_budget: # litellm-specific param
litellm.max_budget = max_budget litellm.max_budget = max_budget
dynamic_config["general"]["max_budget"] = max_budget dynamic_config["general"]["max_budget"] = max_budget
if debug: # litellm-specific param
litellm.set_verbose = True
if save: if save:
save_params_to_config(dynamic_config) save_params_to_config(dynamic_config)
with open(user_config_path) as f: with open(user_config_path) as f:
@ -384,24 +386,25 @@ def logger(
thread = threading.Thread(target=write_to_log, daemon=True) thread = threading.Thread(target=write_to_log, daemon=True)
thread.start() thread.start()
elif log_event_type == 'post_api_call': ## Commenting out post-api call logging as it would break json writes on cli error
if "stream" not in kwargs["optional_params"] or kwargs["optional_params"]["stream"] is False or kwargs.get( # elif log_event_type == 'post_api_call':
"complete_streaming_response", False): # if "stream" not in kwargs["optional_params"] or kwargs["optional_params"]["stream"] is False or kwargs.get(
inference_params = copy.deepcopy(kwargs) # "complete_streaming_response", False):
timestamp = inference_params.pop('start_time') # inference_params = copy.deepcopy(kwargs)
dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23] # timestamp = inference_params.pop('start_time')
# dt_key = timestamp.strftime("%Y%m%d%H%M%S%f")[:23]
with open(log_file, 'r') as f: # with open(log_file, 'r') as f:
existing_data = json.load(f) # existing_data = json.load(f)
existing_data[dt_key]['post_api_call'] = inference_params # existing_data[dt_key]['post_api_call'] = inference_params
def write_to_log(): # def write_to_log():
with open(log_file, 'w') as f: # with open(log_file, 'w') as f:
json.dump(existing_data, f, indent=2) # json.dump(existing_data, f, indent=2)
thread = threading.Thread(target=write_to_log, daemon=True) # thread = threading.Thread(target=write_to_log, daemon=True)
thread.start() # thread.start()
except: except:
pass pass

View file

@ -13,7 +13,7 @@ from click.testing import CliRunner
import pytest import pytest
import litellm import litellm
from litellm.proxy.llm import litellm_completion from litellm.proxy.llm import litellm_completion
from litellm.proxy.proxy_server import initialize
def test_azure_call(): def test_azure_call():
try: try:
data = { data = {
@ -25,4 +25,13 @@ def test_azure_call():
except Exception as e: except Exception as e:
pytest.fail(f"An error occurred: {e}") pytest.fail(f"An error occurred: {e}")
test_azure_call() ## test debug
def test_debug():
try:
initialize(model=None, alias=None, api_base=None, debug=True, temperature=None, max_tokens=None, max_budget=None, telemetry=None, drop_params=None, add_function_to_prompt=None, headers=None, save=None)
assert litellm.set_verbose == True
except Exception as e:
pytest.fail(f"An error occurred: {e}")
# test_debug()
## test logs