diff --git a/litellm/main.py b/litellm/main.py
index 656fca89c..137b5b0d8 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1357,14 +1357,17 @@ def batch_completion_models_all_responses(*args, **kwargs):
 def embedding(
     model,
     input=[],
-    api_key=None,
-    api_base=None,
     # Optional params
     azure=False,
     force_timeout=60,
     litellm_call_id=None,
     litellm_logging_obj=None,
     logger_fn=None,
+    # set api_base, api_version, api_key
+    api_base: Optional[str] = None,
+    api_version: Optional[str] = None,
+    api_key: Optional[str] = None,
+    api_type: Optional[str] = None,
     caching=False,
     custom_llm_provider=None,
 ):
@@ -1375,10 +1378,26 @@ def embedding(
         logging.update_environment_variables(model=model, user="", optional_params={}, litellm_params={"force_timeout": force_timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn})
         if azure == True or custom_llm_provider == "azure":
             # azure configs
-            openai.api_type = get_secret("AZURE_API_TYPE") or "azure"
-            openai.api_base = get_secret("AZURE_API_BASE")
-            openai.api_version = get_secret("AZURE_API_VERSION")
-            openai.api_key = get_secret("AZURE_API_KEY")
+            api_type = get_secret("AZURE_API_TYPE") or "azure"
+
+            api_base = (
+                api_base
+                or litellm.api_base
+                or get_secret("AZURE_API_BASE")
+            )
+
+            api_version = (
+                api_version or
+                litellm.api_version or
+                get_secret("AZURE_API_VERSION")
+            )
+
+            api_key = (
+                api_key or
+                litellm.api_key or
+                litellm.azure_key or
+                get_secret("AZURE_API_KEY")
+            )
             ## LOGGING
             logging.pre_call(
                 input=input,
@@ -1390,30 +1409,61 @@ def embedding(
                 },
             )
             ## EMBEDDING CALL
-            response = openai.Embedding.create(input=input, engine=model)
+            response = openai.Embedding.create(
+                input=input,
+                engine=model,
+                api_key=api_key,
+                api_base=api_base,
+                api_version=api_version,
+                api_type=api_type,
+            )
             ## LOGGING
             logging.post_call(input=input, api_key=openai.api_key, original_response=response)
         elif model in litellm.open_ai_embedding_models:
-            openai.api_type = "openai"
-            openai.api_base = "https://api.openai.com/v1"
-            openai.api_version = None
-            openai.api_key = get_secret("OPENAI_API_KEY")
+            api_base = (
+                api_base
+                or litellm.api_base
+                or get_secret("OPENAI_API_BASE")
+                or "https://api.openai.com/v1"
+            )
+            openai.organization = (
+                litellm.organization
+                or get_secret("OPENAI_ORGANIZATION")
+                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+            )
+            # set API KEY
+            api_key = (
+                api_key or
+                litellm.api_key or
+                litellm.openai_key or
+                get_secret("OPENAI_API_KEY")
+            )
+            api_type = "openai"
+            api_version = None
+
             ## LOGGING
             logging.pre_call(
                 input=input,
-                api_key=openai.api_key,
+                api_key=api_key,
                 additional_args={
-                    "api_type": openai.api_type,
-                    "api_base": openai.api_base,
-                    "api_version": openai.api_version,
+                    "api_type": api_type,
+                    "api_base": api_base,
+                    "api_version": api_version,
                 },
             )
             ## EMBEDDING CALL
-            response = openai.Embedding.create(input=input, model=model)
+            response = openai.Embedding.create(
+                input=input,
+                model=model,
+                api_key=api_key,
+                api_base=api_base,
+                api_version=api_version,
+                api_type=api_type,
+            )
             ## LOGGING
-            logging.post_call(input=input, api_key=openai.api_key, original_response=response)
+            logging.post_call(input=input, api_key=api_key, original_response=response)
         elif model in litellm.cohere_embedding_models:
             cohere_key = (
                 api_key
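
A minimal usage sketch of the updated embedding() signature, for review context only. The key, endpoint, deployment name, and api_version values below are placeholders, not part of this diff; the fallback chains described in the comments are the ones introduced above.

    import litellm

    # OpenAI path: api_key falls back to litellm.api_key / litellm.openai_key /
    # OPENAI_API_KEY; api_base falls back to OPENAI_API_BASE, then the default URL.
    response = litellm.embedding(
        model="text-embedding-ada-002",
        input=["good morning from litellm"],
        api_key="sk-...",  # placeholder key
    )

    # Azure path: azure=True (or custom_llm_provider="azure") routes through the
    # Azure branch; api_key falls back to litellm.azure_key / AZURE_API_KEY, and
    # api_base / api_version fall back to AZURE_API_BASE / AZURE_API_VERSION.
    response = litellm.embedding(
        model="azure-embedding-deployment",  # placeholder deployment name
        input=["good morning from litellm"],
        azure=True,
        api_key="<azure-api-key>",
        api_base="https://<your-resource>.openai.azure.com",
        api_version="2023-07-01-preview",  # placeholder version string
    )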