From 72c1b5dcfcc6a0c59c3f6e18ab89854be965112b Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 12 Aug 2023 16:40:36 -0700
Subject: [PATCH] replacing individual provider flags with 'custom_llm_provider'

---
 litellm/main.py                  | 10 +++++-----
 litellm/tests/test_exceptions.py | 14 +++++++-------
 litellm/utils.py                 | 12 ++++--------
 3 files changed, 16 insertions(+), 20 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index c1103e40b..b0215357e 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -285,7 +285,7 @@ def completion(
         completion_response = response[0].text
         ## LOGGING
-        logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
         prompt_tokens = len(encoding.encode(prompt))
         completion_tokens = len(encoding.encode(completion_response))
         ## RESPONSE OBJECT
@@ -306,11 +306,11 @@ def completion(
         prompt = " ".join([message["content"] for message in messages])
         ## LOGGING
-        logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
         input_payload = {"inputs": prompt}
         response = requests.post(API_URL, headers=headers, json=input_payload)
         ## LOGGING
-        logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": response.text}, logger_fn=logger_fn)
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": response.text}, logger_fn=logger_fn)
         completion_response = response.json()[0]['generated_text']
         prompt_tokens = len(encoding.encode(prompt))
         completion_tokens = len(encoding.encode(completion_response))
@@ -332,7 +332,7 @@ def completion(
         prompt = " ".join([message["content"] for message in messages]) # TODO: Add chat support for together AI
         ## LOGGING
-        logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
         res = requests.post(endpoint, json={
             "model": model,
             "prompt": prompt,
@@ -342,7 +342,7 @@
             headers=headers
         )
         ## LOGGING
-        logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": res.text}, logger_fn=logger_fn)
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": res.text}, logger_fn=logger_fn)
         if stream == True:
             response = CustomStreamWrapper(res, "together_ai")
             return response
diff --git a/litellm/tests/test_exceptions.py b/litellm/tests/test_exceptions.py
index 31e8aa1ac..fa62b1970 100644
--- a/litellm/tests/test_exceptions.py
+++ b/litellm/tests/test_exceptions.py
@@ -37,7 +37,7 @@ def test_context_window(model):
     try:
         azure = model == "chatgpt-test"
         print(f"model: {model}")
-        response = completion(model=model, messages=messages, azure=azure, logger_fn=logging_fn)
+        response = completion(model=model, messages=messages, custom_llm_provider=custom_llm_provider, logger_fn=logging_fn)
         print(f"response: {response}")
     except InvalidRequestError:
         print("InvalidRequestError")
@@ -59,14 +59,14 @@ def invalid_auth(model): # set the model key to an invalid key, depending on the
     messages = [{ "content": "Hello, how are you?","role": "user"}]
     temporary_key = None
     try:
-        azure = False
+        custom_llm_provider = None
         if model == "gpt-3.5-turbo":
             temporary_key = os.environ["OPENAI_API_KEY"]
             os.environ["OPENAI_API_KEY"] = "bad-key"
         elif model == "chatgpt-test":
             temporary_key = os.environ["AZURE_API_KEY"]
             os.environ["AZURE_API_KEY"] = "bad-key"
-            azure = True
+            custom_llm_provider = "azure"
         elif model == "claude-instant-1":
             temporary_key = os.environ["ANTHROPIC_API_KEY"]
             os.environ["ANTHROPIC_API_KEY"] = "bad-key"
@@ -77,7 +77,7 @@ def invalid_auth(model): # set the model key to an invalid key, depending on the
             temporary_key = os.environ["REPLICATE_API_KEY"]
             os.environ["REPLICATE_API_KEY"] = "bad-key"
         print(f"model: {model}")
-        response = completion(model=model, messages=messages, azure=azure)
+        response = completion(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
         print(f"response: {response}")
     except AuthenticationError as e:
         print(f"AuthenticationError Caught Exception - {e}")
@@ -107,11 +107,11 @@ invalid_auth("command-nightly")
 # try:
 #     sample_text = "how does a court case get to the Supreme Court?" * 50000
 #     messages = [{ "content": sample_text,"role": "user"}]
-#     azure = False
+#     custom_llm_provider = None
 #     if model == "chatgpt-test":
-#         azure = True
+#         custom_llm_provider = "azure"
 #     print(f"model: {model}")
-#     response = completion(model=model, messages=messages, azure=azure)
+#     response = completion(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
 # except RateLimitError:
 #     return True
 # except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
diff --git a/litellm/utils.py b/litellm/utils.py
index fc7600cb7..5dec0f220 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -75,10 +75,6 @@ def logging(model=None, input=None, custom_llm_provider=None, azure=False, addit
         model_call_details["custom_llm_provider"] = custom_llm_provider
     if exception:
         model_call_details["exception"] = exception
-
-    # if litellm.telemetry:
-    #     safe_crash_reporting(model=model, exception=exception, azure=azure) # log usage-crash details. Do not log any user details. If you want to turn this off, set `litellm.telemetry=False`.
-
     if input:
         model_call_details["input"] = input
@@ -134,8 +130,8 @@ def client(original_function):
     try:
         model = args[0] if len(args) > 0 else kwargs["model"]
         exception = kwargs["exception"] if "exception" in kwargs else None
-        azure = kwargs["azure"] if "azure" in kwargs else None
-        safe_crash_reporting(model=model, exception=exception, azure=azure) # log usage-crash details. Do not log any user details. If you want to turn this off, set `litellm.telemetry=False`.
+        custom_llm_provider = kwargs["custom_llm_provider"] if "custom_llm_provider" in kwargs else None
+        safe_crash_reporting(model=model, exception=exception, custom_llm_provider=custom_llm_provider) # log usage-crash details. Do not log any user details. If you want to turn this off, set `litellm.telemetry=False`.
     except: #[Non-Blocking Error]
         pass
@@ -647,11 +643,11 @@ def exception_type(model, original_exception):
         else: # don't let an error with mapping interrupt the user from receiving an error from the llm api calls
             raise original_exception
 
-def safe_crash_reporting(model=None, exception=None, azure=None):
+def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None):
     data = {
         "model": model,
         "exception": str(exception),
-        "azure": azure
+        "custom_llm_provider": custom_llm_provider
     }
     threading.Thread(target=litellm_telemetry, args=(data,)).start()
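
Note (not part of the patch): a minimal caller-side sketch of what this change implies, assuming the completion() keyword shown in the updated tests. The model name and messages below are illustrative only.

from litellm import completion

messages = [{"content": "Hello, how are you?", "role": "user"}]

# Before this patch, Azure routing was requested with a boolean flag:
#   response = completion(model="chatgpt-test", messages=messages, azure=True)
# After this patch, the provider is named explicitly with a single string
# argument (None falls back to the default provider routing):
response = completion(
    model="chatgpt-test",
    messages=messages,
    custom_llm_provider="azure",
)
print(response)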
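
Also illustrative, not part of the patch: the crash-report payload built by safe_crash_reporting() in litellm/utils.py before and after this change. Only the dictionary keys come from the diff; the model and exception values here are invented.

# Telemetry payload shape, keys taken from the diff; values are made up.
payload_before = {
    "model": "chatgpt-test",
    "exception": "AuthenticationError(...)",
    "azure": True,  # old boolean provider flag
}

payload_after = {
    "model": "chatgpt-test",
    "exception": "AuthenticationError(...)",
    "custom_llm_provider": "azure",  # provider name, or None when not supplied
}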