diff --git a/litellm/llms/cohere.py b/litellm/llms/cohere.py index 0cf8949f70..734dfa8da6 100644 --- a/litellm/llms/cohere.py +++ b/litellm/llms/cohere.py @@ -195,6 +195,7 @@ def embedding( logging_obj=None, model_response=None, encoding=None, + optional_params=None, ): headers = validate_environment(api_key) embed_url = "https://api.cohere.ai/v1/embed" @@ -202,8 +203,13 @@ def embedding( data = { "model": model, "texts": input, + **optional_params } + if "3" in model and "input_type" not in data: + # cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document" + data["input_type"] = "search_document" + ## LOGGING logging_obj.pre_call( input=input, diff --git a/litellm/main.py b/litellm/main.py index 90c9b37e69..b4ecc67d4b 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1764,6 +1764,7 @@ def embedding( response = cohere.embedding( model=model, input=input, + optional_params=kwargs, encoding=encoding, api_key=cohere_key, logging_obj=logging,