thread safe azure

This commit is contained in:
ishaan-jaff 2023-09-02 18:05:05 -07:00
parent 96fce0a880
commit 67d881acb4

View file

@ -163,11 +163,13 @@ def completion(
if custom_llm_provider == "azure": if custom_llm_provider == "azure":
# azure configs # azure configs
openai.api_type = "azure" openai.api_type = "azure"
openai.api_base = (
litellm.api_base api_base = (
if litellm.api_base is not None api_base
else get_secret("AZURE_API_BASE") or litellm.api_base
or get_secret("AZURE_API_BASE")
) )
openai.api_version = ( openai.api_version = (
litellm.api_version litellm.api_version
if litellm.api_version is not None if litellm.api_version is not None
@ -177,8 +179,7 @@ def completion(
api_key = litellm.azure_key api_key = litellm.azure_key
elif not api_key and get_secret("AZURE_API_KEY"): elif not api_key and get_secret("AZURE_API_KEY"):
api_key = get_secret("AZURE_API_KEY") api_key = get_secret("AZURE_API_KEY")
# set key
openai.api_key = api_key
## LOGGING ## LOGGING
logging.pre_call( logging.pre_call(
input=messages, input=messages,
@ -190,19 +191,14 @@ def completion(
}, },
) )
## COMPLETION CALL ## COMPLETION CALL
if litellm.headers: response = openai.ChatCompletion.create(
response = openai.ChatCompletion.create( engine=model,
engine=model, messages=messages,
messages=messages, headers=litellm.headers,
headers=litellm.headers, api_key=api_key,
**optional_params, api_base=api_base,
) **optional_params,
else: )
response = openai.ChatCompletion.create(
engine=model,
messages=messages,
**optional_params
)
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
response = CustomStreamWrapper(response, model, logging_obj=logging) response = CustomStreamWrapper(response, model, logging_obj=logging)
return response return response
@ -253,7 +249,7 @@ def completion(
headers=litellm.headers, # None by default headers=litellm.headers, # None by default
api_base=api_base, # thread safe setting base, key, api_version api_base=api_base, # thread safe setting base, key, api_version
api_key=api_key, api_key=api_key,
api_version=api_version # default None api_version=api_version, # default None
**optional_params, **optional_params,
) )
except Exception as e: except Exception as e: