forked from phoenix/litellm-mirror
(feat) set api_base, api_key, api_version for embedding()
parent b77492d574
commit b72dbe61c0
1 changed file with 67 additions and 17 deletions
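A minimal usage sketch (not part of this commit) of what the new keyword arguments presumably allow: passing the Azure credentials straight to embedding() instead of setting the openai module-level globals. The deployment name, endpoint, key, and version below are placeholders.

import litellm

# Illustrative only - placeholder deployment name and credentials.
response = litellm.embedding(
    model="my-azure-embedding-deployment",             # placeholder deployment
    input=["good morning from litellm"],
    azure=True,                                         # route to the azure branch shown in the diff
    api_key="my-azure-api-key",                         # placeholder; falls back to litellm.api_key / AZURE_API_KEY
    api_base="https://my-endpoint.openai.azure.com",    # placeholder endpoint
    api_version="2023-07-01-preview",                   # placeholder version string
)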
@@ -1357,14 +1357,17 @@ def batch_completion_models_all_responses(*args, **kwargs):
 def embedding(
     model,
     input=[],
-    api_key=None,
-    api_base=None,
     # Optional params
     azure=False,
     force_timeout=60,
     litellm_call_id=None,
     litellm_logging_obj=None,
     logger_fn=None,
+    # set api_base, api_version, api_key
+    api_base: Optional[str] = None,
+    api_version: Optional[str] = None,
+    api_key: Optional[str] = None,
+    api_type: Optional[str] = None,
     caching=False,
     custom_llm_provider=None,
 ):
@@ -1375,10 +1378,26 @@ def embedding(
     logging.update_environment_variables(model=model, user="", optional_params={}, litellm_params={"force_timeout": force_timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn})
     if azure == True or custom_llm_provider == "azure":
         # azure configs
-        openai.api_type = get_secret("AZURE_API_TYPE") or "azure"
-        openai.api_base = get_secret("AZURE_API_BASE")
-        openai.api_version = get_secret("AZURE_API_VERSION")
-        openai.api_key = get_secret("AZURE_API_KEY")
+        api_type = get_secret("AZURE_API_TYPE") or "azure"
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret("AZURE_API_BASE")
+        )
+
+        api_version = (
+            api_version or
+            litellm.api_version or
+            get_secret("AZURE_API_VERSION")
+        )
+
+        api_key = (
+            api_key or
+            litellm.api_key or
+            litellm.azure_key or
+            get_secret("AZURE_API_KEY")
+        )
         ## LOGGING
         logging.pre_call(
             input=input,
@@ -1390,30 +1409,61 @@ def embedding(
             },
         )
         ## EMBEDDING CALL
-        response = openai.Embedding.create(input=input, engine=model)
+        response = openai.Embedding.create(
+            input=input,
+            engine=model,
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            api_type=api_type,
+        )
+
         ## LOGGING
         logging.post_call(input=input, api_key=openai.api_key, original_response=response)
     elif model in litellm.open_ai_embedding_models:
-        openai.api_type = "openai"
-        openai.api_base = "https://api.openai.com/v1"
-        openai.api_version = None
-        openai.api_key = get_secret("OPENAI_API_KEY")
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )
+        openai.organization = (
+            litellm.organization
+            or get_secret("OPENAI_ORGANIZATION")
+            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+        )
+        # set API KEY
+        api_key = (
+            api_key or
+            litellm.api_key or
+            litellm.openai_key or
+            get_secret("OPENAI_API_KEY")
+        )
+        api_type = "openai"
+        api_version = None
+
         ## LOGGING
         logging.pre_call(
             input=input,
-            api_key=openai.api_key,
+            api_key=api_key,
             additional_args={
-                "api_type": openai.api_type,
-                "api_base": openai.api_base,
-                "api_version": openai.api_version,
+                "api_type": api_type,
+                "api_base": api_base,
+                "api_version": api_version,
             },
         )
         ## EMBEDDING CALL
-        response = openai.Embedding.create(input=input, model=model)
+        response = openai.Embedding.create(
+            input=input,
+            model=model,
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            api_type=api_type,
+        )
+
         ## LOGGING
-        logging.post_call(input=input, api_key=openai.api_key, original_response=response)
+        logging.post_call(input=input, api_key=api_key, original_response=response)
     elif model in litellm.cohere_embedding_models:
         cohere_key = (
            api_key
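Both branches apply the same precedence chain when resolving each setting: an explicitly passed argument wins, then the module-level litellm attribute, then the secret/environment value (with a hard-coded default last for the OpenAI base URL). A standalone sketch of that resolution order; the helper name and the plain os.environ lookup are illustrative, not from this commit.

import os

# Illustrative sketch of the precedence used in the diff above:
# explicit argument > module-level setting > environment/secret value > default.
def resolve_setting(explicit_value, module_value, env_var, default=None):
    """Return the first truthy value in precedence order."""
    return explicit_value or module_value or os.environ.get(env_var) or default

# e.g. api_base = resolve_setting(api_base, litellm.api_base, "AZURE_API_BASE")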