(fix) deepinfra with openai v1.0.0

This commit is contained in:
ishaan-jaff 2023-11-13 09:51:22 -08:00
parent cf0ab7155e
commit 27cbd7d895
2 changed files with 5 additions and 55 deletions

View file

@ -812,61 +812,6 @@ def completion(
response = CustomStreamWrapper(model_response, model, custom_llm_provider="maritalk", logging_obj=logging) response = CustomStreamWrapper(model_response, model, custom_llm_provider="maritalk", logging_obj=logging)
return response return response
response = model_response response = model_response
elif custom_llm_provider == "deepinfra": # for now this NEEDS to be above Hugging Face otherwise all calls to meta-llama/Llama-2-70b-chat-hf go to hf, we need this to go to deep infra if user sets provider to deep infra
# this can be called with the openai python package
api_key = (
api_key or
litellm.api_key or
litellm.openai_key or
get_secret("DEEPINFRA_API_KEY")
)
api_base = (
api_base
or litellm.api_base
or get_secret("DEEPINFRA_API_BASE")
or "https://api.deepinfra.com/v1/openai"
)
headers = (
headers or
litellm.headers
)
## LOGGING
logging.pre_call(
input=messages,
api_key=api_key,
)
## COMPLETION CALL
openai.api_key = api_key # set key for deep infra
openai.base_url = api_base # use the deepinfra api base
try:
response = openai.chat.completions.create(
model=model, # type: ignore
messages=messages, # type: ignore
api_type="openai", # type: ignore
api_version=api_version, # type: ignore
**optional_params, # type: ignore
)
except Exception as e:
## LOGGING - log the original exception returned
logging.post_call(
input=messages,
api_key=api_key,
original_response=str(e),
)
raise e
if "stream" in optional_params and optional_params["stream"] == True:
response = CustomStreamWrapper(response, model, custom_llm_provider="openai", logging_obj=logging)
return response
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
additional_args={"headers": headers},
)
elif ( elif (
custom_llm_provider == "huggingface" custom_llm_provider == "huggingface"
): ):

View file

@ -1992,6 +1992,11 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
api_base = "https://api.endpoints.anyscale.com/v1" api_base = "https://api.endpoints.anyscale.com/v1"
dynamic_api_key = os.getenv("ANYSCALE_API_KEY") dynamic_api_key = os.getenv("ANYSCALE_API_KEY")
custom_llm_provider = "custom_openai" custom_llm_provider = "custom_openai"
elif custom_llm_provider == "deepinfra":
# deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
api_base = "https://api.deepinfra.com/v1/openai"
dynamic_api_key = os.getenv("DEEPINFRA_API_KEY")
custom_llm_provider = "custom_openai"
return model, custom_llm_provider, dynamic_api_key, api_base return model, custom_llm_provider, dynamic_api_key, api_base
# check if api base is a known openai compatible endpoint # check if api base is a known openai compatible endpoint