From ed2d461770d2a8bb1bc5baeb9eb99771a252acd5 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Mon, 25 Sep 2023 21:42:45 -0700
Subject: [PATCH] add beta deep infra support

---
 litellm/__init__.py              |  1 +
 litellm/main.py                  | 42 ++++++++++++++++++++++++++++++++
 litellm/tests/test_completion.py | 17 +++++++++++--
 3 files changed, 58 insertions(+), 2 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index df81cf7b3..e99c8cf9e 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -241,6 +241,7 @@ provider_list: List = [
     "petals",
     "oobabooga",
     "ollama",
+    "deepinfra",
     "custom", # custom apis
 ]

diff --git a/litellm/main.py b/litellm/main.py
index cc4015b4c..df4047c74 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -663,6 +663,48 @@ def completion(
             response = CustomStreamWrapper(model_response, model, custom_llm_provider="cohere", logging_obj=logging)
             return response
         response = model_response
+    elif custom_llm_provider == "deepinfra": # for now this NEEDS to sit above the Hugging Face branch; otherwise calls to meta-llama/Llama-2-70b-chat-hf route to HF even when the user sets the provider to deepinfra
+        # Deep Infra exposes an OpenAI-compatible endpoint, so this can be called with the openai python package
+        api_key = (
+            api_key or
+            litellm.api_key or
+            litellm.openai_key or
+            get_secret("DEEPINFRA_API_KEY")
+        )
+        ## LOGGING
+        logging.pre_call(
+            input=messages,
+            api_key=api_key,
+        )
+        ## COMPLETION CALL
+        openai.api_key = api_key # set key for deep infra
+        try:
+            response = openai.ChatCompletion.create(
+                model=model,
+                messages=messages,
+                api_base="https://api.deepinfra.com/v1/openai", # use the deepinfra api base
+                api_type="openai",
+                api_version=api_version, # default None
+                **optional_params,
+            )
+        except Exception as e:
+            ## LOGGING - log the original exception returned
+            logging.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=str(e),
+            )
+            raise e
+        if "stream" in optional_params and optional_params["stream"] == True:
+            response = CustomStreamWrapper(response, model, custom_llm_provider="openai", logging_obj=logging)
+            return response
+        ## LOGGING
+        logging.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=response,
+            additional_args={"headers": litellm.headers},
+        )
     elif (
         (
             model in litellm.huggingface_models and

diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index dad571310..7007e923a 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -48,7 +48,7 @@ def test_completion_claude():
         print(response.response_ms)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_claude()
+# test_completion_claude()

 def test_completion_claude_max_tokens():
     try:
@@ -67,7 +67,7 @@ def test_completion_claude_max_tokens():
         print(response.response_ms)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_claude_max_tokens()
+# test_completion_claude_max_tokens()

 # def test_completion_oobabooga():
 #     try:
@@ -854,6 +854,19 @@ def test_completion_ai21():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+
+## test deep infra
+def test_completion_deep_infra():
+    # litellm.set_verbose = True
+    model_name = "deepinfra/meta-llama/Llama-2-70b-chat-hf"
+    try:
+        response = completion(model=model_name, messages=messages)
+        # Add any assertions here to check the response
+        print(response)
+        print(response.response_ms)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_completion_deep_infra()
 # test_completion_ai21()
 # test config file with completion #
 # def test_completion_openai_config():
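
---
Usage note for reviewers: a minimal sketch of calling the provider this
patch adds, mirroring the new test_completion_deep_infra test. It assumes
DEEPINFRA_API_KEY is exported in the environment; the messages payload is
illustrative, not taken from the patch.

    import litellm

    # model string follows the new provider routing: "deepinfra/<model-id>",
    # as in the added test; requests go out via Deep Infra's OpenAI-compatible
    # endpoint (https://api.deepinfra.com/v1/openai) set in the new branch
    response = litellm.completion(
        model="deepinfra/meta-llama/Llama-2-70b-chat-hf",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response)

Because the branch reuses openai.ChatCompletion.create with a custom
api_base, any kwargs accepted there (e.g. stream=True) should pass through
via optional_params, with streaming wrapped by CustomStreamWrapper as shown
in the diff.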