mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
add cohere embedding models
This commit is contained in:
parent
8213e23970
commit
af914d4be1
5 changed files with 127 additions and 9 deletions
|
@ -54,6 +54,7 @@ from litellm.utils import (
|
|||
get_secret,
|
||||
CustomStreamWrapper,
|
||||
ModelResponse,
|
||||
EmbeddingResponse,
|
||||
read_config_args,
|
||||
)
|
||||
|
||||
|
@ -1352,7 +1353,15 @@ def batch_completion_models_all_responses(*args, **kwargs):
|
|||
60
|
||||
) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
|
||||
def embedding(
|
||||
model, input=[], azure=False, force_timeout=60, litellm_call_id=None, litellm_logging_obj=None, logger_fn=None, caching=False,
|
||||
model,
|
||||
input=[],
|
||||
azure=False,
|
||||
force_timeout=60,
|
||||
litellm_call_id=None,
|
||||
litellm_logging_obj=None,
|
||||
logger_fn=None,
|
||||
caching=False,
|
||||
api_key=None,
|
||||
):
|
||||
try:
|
||||
response = None
|
||||
|
@ -1393,6 +1402,23 @@ def embedding(
|
|||
)
|
||||
## EMBEDDING CALL
|
||||
response = openai.Embedding.create(input=input, model=model)
|
||||
elif model in litellm.cohere_embedding_models:
|
||||
cohere_key = (
|
||||
api_key
|
||||
or litellm.cohere_key
|
||||
or get_secret("COHERE_API_KEY")
|
||||
or get_secret("CO_API_KEY")
|
||||
or litellm.api_key
|
||||
)
|
||||
response = cohere.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
encoding=encoding,
|
||||
api_key=cohere_key,
|
||||
logging_obj=logging,
|
||||
model_response= EmbeddingResponse()
|
||||
|
||||
)
|
||||
else:
|
||||
args = locals()
|
||||
raise ValueError(f"No valid embedding model args passed in - {args}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue