thread safe version, key and base for openai

This commit is contained in:
ishaan-jaff 2023-09-02 17:49:25 -07:00
parent f99f2a3a4d
commit 1bc2a6d5cc

View file

@ -87,6 +87,7 @@ def completion(
*, *,
return_async=False, return_async=False,
api_key=None, api_key=None,
api_version=None,
force_timeout=600, force_timeout=600,
logger_fn=None, logger_fn=None,
verbose=False, verbose=False,
@ -198,7 +199,9 @@ def completion(
) )
else: else:
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
engine=model, messages=messages, **optional_params engine=model,
messages=messages,
**optional_params
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
response = CustomStreamWrapper(response, model, logging_obj=logging) response = CustomStreamWrapper(response, model, logging_obj=logging)
@ -228,7 +231,6 @@ def completion(
or get_secret("OPENAI_API_BASE") or get_secret("OPENAI_API_BASE")
or "https://api.openai.com/v1" or "https://api.openai.com/v1"
) )
openai.api_version = None
if litellm.organization: if litellm.organization:
openai.organization = litellm.organization openai.organization = litellm.organization
# set API KEY # set API KEY
@ -237,8 +239,6 @@ def completion(
elif not api_key and get_secret("OPENAI_API_KEY"): elif not api_key and get_secret("OPENAI_API_KEY"):
api_key = get_secret("OPENAI_API_KEY") api_key = get_secret("OPENAI_API_KEY")
openai.api_key = api_key
## LOGGING ## LOGGING
logging.pre_call( logging.pre_call(
input=messages, input=messages,
@ -247,21 +247,15 @@ def completion(
) )
## COMPLETION CALL ## COMPLETION CALL
try: try:
if litellm.headers: response = openai.ChatCompletion.create(
response = openai.ChatCompletion.create( model=model,
model=model, messages=messages,
messages=messages, headers=litellm.headers, # None by default
headers=litellm.headers, api_base=api_base, # thread safe setting base, key, api_version
api_base=api_base, api_key=api_key,
**optional_params, api_version=api_version # default None
) **optional_params,
else: )
response = openai.ChatCompletion.create(
model=model,
messages=messages,
api_base=api_base, # thread safe setting of api_base
**optional_params
)
except Exception as e: except Exception as e:
## LOGGING - log the original exception returned ## LOGGING - log the original exception returned
logging.post_call( logging.post_call(