(feat) embedding() remove junk params

This commit is contained in:
ishaan-jaff 2023-11-22 14:03:41 -08:00
parent 0b4e10e068
commit b3bca98561

View file

@ -1672,11 +1672,7 @@ def embedding(
model, model,
input=[], input=[],
# Optional params # Optional params
azure=False, timeout=600, # default to 10 minutes
force_timeout=60,
litellm_call_id=None,
litellm_logging_obj=None,
logger_fn=None,
# set api_base, api_version, api_key # set api_base, api_version, api_key
api_base: Optional[str] = None, api_base: Optional[str] = None,
api_version: Optional[str] = None, api_version: Optional[str] = None,
@ -1684,6 +1680,9 @@ def embedding(
api_type: Optional[str] = None, api_type: Optional[str] = None,
caching: bool=False, caching: bool=False,
custom_llm_provider=None, custom_llm_provider=None,
litellm_call_id=None,
litellm_logging_obj=None,
logger_fn=None,
**kwargs **kwargs
): ):
""" """
@ -1692,8 +1691,7 @@ def embedding(
Parameters: Parameters:
- model: The embedding model to use. - model: The embedding model to use.
- input: The input for which embeddings are to be generated. - input: The input for which embeddings are to be generated.
- azure: A boolean indicating whether to use the Azure API for embedding. - timeout: The timeout value for the API call, default 10 mins
- force_timeout: The timeout value for the API call.
- litellm_call_id: The call ID for litellm logging. - litellm_call_id: The call ID for litellm logging.
- litellm_logging_obj: The litellm logging object. - litellm_logging_obj: The litellm logging object.
- logger_fn: The logger function. - logger_fn: The logger function.
@ -1710,11 +1708,12 @@ def embedding(
Raises: Raises:
- exception_type: If an exception occurs during the API call. - exception_type: If an exception occurs during the API call.
""" """
azure = kwargs.get("azure", None)
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)
try: try:
response = None response = None
logging = litellm_logging_obj logging = litellm_logging_obj
logging.update_environment_variables(model=model, user="", optional_params={}, litellm_params={"force_timeout": force_timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn}) logging.update_environment_variables(model=model, user="", optional_params={}, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn})
if azure == True or custom_llm_provider == "azure": if azure == True or custom_llm_provider == "azure":
# azure configs # azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure" api_type = get_secret("AZURE_API_TYPE") or "azure"
@ -1751,8 +1750,9 @@ def embedding(
api_version=api_version, api_version=api_version,
azure_ad_token=azure_ad_token, azure_ad_token=azure_ad_token,
logging_obj=logging, logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(), model_response=EmbeddingResponse(),
optional_params=kwargs optional_params=kwargs,
) )
elif model in litellm.open_ai_embedding_models: elif model in litellm.open_ai_embedding_models:
api_base = ( api_base = (
@ -1784,6 +1784,7 @@ def embedding(
api_base=api_base, api_base=api_base,
api_key=api_key, api_key=api_key,
logging_obj=logging, logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(), model_response=EmbeddingResponse(),
optional_params=kwargs optional_params=kwargs
) )