add beta DeepInfra support

This commit is contained in:
ishaan-jaff 2023-09-25 21:42:45 -07:00
parent e9747fb763
commit ed2d461770
3 changed files with 58 additions and 2 deletions
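For a sense of the user-facing surface, here is a minimal usage sketch based on the test added in this commit; the "deepinfra/" prefix on the model name is what should route the call to the new branch, and it assumes a valid DEEPINFRA_API_KEY:

import os
from litellm import completion

os.environ["DEEPINFRA_API_KEY"] = "..."  # substitute a real key

response = completion(
    model="deepinfra/meta-llama/Llama-2-70b-chat-hf",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response)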

@@ -241,6 +241,7 @@ provider_list: List = [
     "petals",
     "oobabooga",
     "ollama",
+    "deepinfra",
     "custom", # custom apis
 ]

@@ -663,6 +663,48 @@ def completion(
             response = CustomStreamWrapper(model_response, model, custom_llm_provider="cohere", logging_obj=logging)
             return response
         response = model_response
+    elif custom_llm_provider == "deepinfra": # for now this NEEDS to be above Hugging Face; otherwise all calls to meta-llama/Llama-2-70b-chat-hf go to HF, and they need to go to DeepInfra when the user sets the provider to deepinfra
+        # this can be called with the openai python package
+        api_key = (
+            api_key or
+            litellm.api_key or
+            litellm.openai_key or
+            get_secret("DEEPINFRA_API_KEY")
+        )
+        ## LOGGING
+        logging.pre_call(
+            input=messages,
+            api_key=api_key,
+        )
+        ## COMPLETION CALL
+        openai.api_key = api_key # set key for deep infra
+        try:
+            response = openai.ChatCompletion.create(
+                model=model,
+                messages=messages,
+                api_base="https://api.deepinfra.com/v1/openai", # use the deepinfra api base
+                api_type="openai",
+                api_version=api_version, # default None
+                **optional_params,
+            )
+        except Exception as e:
+            ## LOGGING - log the original exception returned
+            logging.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=str(e),
+            )
+            raise e
+        if "stream" in optional_params and optional_params["stream"] == True:
+            response = CustomStreamWrapper(response, model, custom_llm_provider="openai", logging_obj=logging)
+            return response
+        ## LOGGING
+        logging.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=response,
+            additional_args={"headers": litellm.headers},
+        )
     elif (
         (
             model in litellm.huggingface_models and
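The branch above drives DeepInfra through the 0.x openai python package, which accepts per-request api_base/api_key kwargs; for illustration, a standalone sketch of roughly the call it ends up making (the message content here is made up):

import os
import openai  # openai<1.0, matching the ChatCompletion.create call above

response = openai.ChatCompletion.create(
    model="meta-llama/Llama-2-70b-chat-hf",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],  # illustrative message
    api_key=os.environ["DEEPINFRA_API_KEY"],  # same env var the branch reads via get_secret
    api_base="https://api.deepinfra.com/v1/openai",  # DeepInfra's OpenAI-compatible endpoint
    api_type="openai",
)
print(response["choices"][0]["message"]["content"])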

@@ -48,7 +48,7 @@ def test_completion_claude():
         print(response.response_ms)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_claude()
+# test_completion_claude()

 def test_completion_claude_max_tokens():
     try:
@@ -67,7 +67,7 @@ def test_completion_claude_max_tokens():
         print(response.response_ms)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_claude_max_tokens()
+# test_completion_claude_max_tokens()

 # def test_completion_oobabooga():
 #     try:
@@ -854,6 +854,19 @@ def test_completion_ai21():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+## test deep infra
+def test_completion_deep_infra():
+    # litellm.set_verbose = True
+    model_name = "deepinfra/meta-llama/Llama-2-70b-chat-hf"
+    try:
+        response = completion(model=model_name, messages=messages)
+        # Add any assertions here to check the response
+        print(response)
+        print(response.response_ms)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_completion_deep_infra()
+
 # test_completion_ai21()
 # test config file with completion #
 # def test_completion_openai_config():