(test) test completion: if 'user' passed to API

This commit is contained in:
ishaan-jaff 2023-12-04 09:45:03 -08:00
parent 31d9762b50
commit 93f5c266da
2 changed files with 45 additions and 5 deletions

View file

@ -442,9 +442,46 @@ def test_completion_text_openai():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_completion_text_openai() # test_completion_text_openai()
def custom_callback(
kwargs, # kwargs to completion
completion_response, # response from completion
start_time, end_time # start/end time
):
# Your custom code here
try:
print("LITELLM: in custom callback function")
print("\nkwargs\n", kwargs)
model = kwargs["model"]
messages = kwargs["messages"]
user = kwargs.get("user")
#################################################
print(
f"""
Model: {model},
Messages: {messages},
User: {user},
Seed: {kwargs["seed"]},
temperature: {kwargs["temperature"]},
"""
)
assert kwargs["user"] == "ishaans app"
assert kwargs["model"] == "gpt-3.5-turbo-1106"
assert kwargs["seed"] == 12
assert kwargs["temperature"] == 0.5
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_openai_with_optional_params(): def test_completion_openai_with_optional_params():
# [Proxy PROD TEST] WARNING: DO NOT DELETE THIS TEST
# assert that `user` gets passed to the completion call
# Note: This tests that we actually send the optional params to the completion call
# We use custom callbacks to test this
try: try:
litellm.set_verbose = True litellm.set_verbose = True
litellm.success_callback = [custom_callback]
response = completion( response = completion(
model="gpt-3.5-turbo-1106", model="gpt-3.5-turbo-1106",
messages=[ messages=[
@ -458,15 +495,17 @@ def test_completion_openai_with_optional_params():
seed=12, seed=12,
response_format={ "type": "json_object" }, response_format={ "type": "json_object" },
logit_bias=None, logit_bias=None,
user = "ishaans app"
) )
# Add any assertions here to check the response # Add any assertions here to check the response
print(response) print(response)
except litellm.Timeout as e: litellm.success_callback = [] # unset callbacks
pass
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_completion_openai_with_optional_params() test_completion_openai_with_optional_params()
def test_completion_openai_litellm_key(): def test_completion_openai_litellm_key():
try: try:
@ -1337,7 +1376,7 @@ def test_azure_cloudflare_api():
traceback.print_exc() traceback.print_exc()
pass pass
test_azure_cloudflare_api() # test_azure_cloudflare_api()
def test_completion_anyscale_2(): def test_completion_anyscale_2():
try: try:

View file

@ -544,7 +544,8 @@ class Logging:
"optional_params": self.optional_params, "optional_params": self.optional_params,
"litellm_params": self.litellm_params, "litellm_params": self.litellm_params,
"start_time": self.start_time, "start_time": self.start_time,
"stream": self.stream "stream": self.stream,
**self.optional_params
} }
def pre_call(self, input, api_key, model=None, additional_args={}): def pre_call(self, input, api_key, model=None, additional_args={}):