diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 3fbbdea0e..54bc53326 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -442,9 +442,46 @@ def test_completion_text_openai(): pytest.fail(f"Error occurred: {e}") # test_completion_text_openai() +def custom_callback( + kwargs, # kwargs to completion + completion_response, # response from completion + start_time, end_time # start/end time +): + # Your custom code here + try: + print("LITELLM: in custom callback function") + print("\nkwargs\n", kwargs) + model = kwargs["model"] + messages = kwargs["messages"] + user = kwargs.get("user") + + ################################################# + + print( + f""" + Model: {model}, + Messages: {messages}, + User: {user}, + Seed: {kwargs["seed"]}, + temperature: {kwargs["temperature"]}, + """ + ) + + assert kwargs["user"] == "ishaans app" + assert kwargs["model"] == "gpt-3.5-turbo-1106" + assert kwargs["seed"] == 12 + assert kwargs["temperature"] == 0.5 + except Exception as e: + pytest.fail(f"Error occurred: {e}") + def test_completion_openai_with_optional_params(): + # [Proxy PROD TEST] WARNING: DO NOT DELETE THIS TEST + # assert that `user` gets passed to the completion call + # Note: This tests that we actually send the optional params to the completion call + # We use custom callbacks to test this try: litellm.set_verbose = True + litellm.success_callback = [custom_callback] response = completion( model="gpt-3.5-turbo-1106", messages=[ @@ -458,15 +495,17 @@ def test_completion_openai_with_optional_params(): seed=12, response_format={ "type": "json_object" }, logit_bias=None, + user = "ishaans app" ) # Add any assertions here to check the response + print(response) - except litellm.Timeout as e: - pass + litellm.success_callback = [] # unset callbacks + except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_completion_openai_with_optional_params() +test_completion_openai_with_optional_params() def test_completion_openai_litellm_key(): try: @@ -1337,7 +1376,7 @@ def test_azure_cloudflare_api(): traceback.print_exc() pass -test_azure_cloudflare_api() +# test_azure_cloudflare_api() def test_completion_anyscale_2(): try: diff --git a/litellm/utils.py b/litellm/utils.py index 280a6342f..b756fc358 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -544,7 +544,8 @@ class Logging: "optional_params": self.optional_params, "litellm_params": self.litellm_params, "start_time": self.start_time, - "stream": self.stream + "stream": self.stream, + **self.optional_params } def pre_call(self, input, api_key, model=None, additional_args={}):