From af914d4be14b623d366b8e2859db8754401baa4f Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Fri, 29 Sep 2023 09:59:31 -0700
Subject: [PATCH] add cohere embedding models

---
 litellm/__init__.py             |  1 +
 litellm/llms/cohere.py          | 74 +++++++++++++++++++++++++++++++--
 litellm/main.py                 | 28 ++++++++++++-
 litellm/tests/test_embedding.py | 20 ++++++---
 litellm/utils.py                | 13 ++++++
 5 files changed, 127 insertions(+), 9 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index a357292c9..d806f534f 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -286,6 +286,7 @@ longer_context_model_fallback_dict: dict = {
 
 ####### EMBEDDING MODELS ###################
 open_ai_embedding_models: List = ["text-embedding-ada-002"]
+cohere_embedding_models: List = ["embed-english-v2.0", "embed-english-light-v2.0", "embed-multilingual-v2.0"]
 
 from .timeout import timeout
 from .testing import *
diff --git a/litellm/llms/cohere.py b/litellm/llms/cohere.py
index 113b6b542..7982d94f4 100644
--- a/litellm/llms/cohere.py
+++ b/litellm/llms/cohere.py
@@ -96,6 +96,74 @@ def completion(
     }
     return model_response
 
-def embedding():
-    # logic for parsing in - calling - parsing out model embedding calls
-    pass
+def embedding(
+    model: str,
+    input: list,
+    api_key: str,
+    logging_obj=None,
+    model_response=None,
+    encoding=None,
+):
+    headers = validate_environment(api_key)
+    embed_url = "https://api.cohere.ai/v1/embed"
+    model = model
+    data = {
+        "model": model,
+        "texts": input,
+    }
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
+    ## COMPLETION CALL
+    response = requests.post(
+        embed_url, headers=headers, data=json.dumps(data)
+    )
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+        original_response=response,
+    )
+    # print(response.json())
+    """
+    response
+    {
+        'object': "list",
+        'data': [
+
+        ]
+        'model',
+        'usage'
+    }
+    """
+    embeddings = response.json()['embeddings']
+    output_data = []
+    for idx, embedding in enumerate(embeddings):
+        output_data.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding
+            }
+        )
+    model_response["object"] = "list"
+    model_response["data"] = output_data
+    model_response["model"] = model
+    input_tokens = 0
+    for text in input:
+        input_tokens+=len(encoding.encode(text))
+
+    model_response["usage"] = {
+        "prompt_tokens": input_tokens,
+        "total_tokens": input_tokens,
+    }
+    return model_response
+
+
+
\ No newline at end of file
diff --git a/litellm/main.py b/litellm/main.py
index 08c4758de..fafaf5657 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -54,6 +54,7 @@ from litellm.utils import (
     get_secret,
     CustomStreamWrapper,
     ModelResponse,
+    EmbeddingResponse,
     read_config_args,
 )
 
@@ -1352,7 +1353,15 @@ def batch_completion_models_all_responses(*args, **kwargs):
     60
 )  ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
 def embedding(
-    model, input=[], azure=False, force_timeout=60, litellm_call_id=None, litellm_logging_obj=None, logger_fn=None, caching=False,
+    model,
+    input=[],
+    azure=False,
+    force_timeout=60,
+    litellm_call_id=None,
+    litellm_logging_obj=None,
+    logger_fn=None,
+    caching=False,
+    api_key=None,
 ):
     try:
         response = None
@@ -1393,6 +1402,23 @@ def embedding(
             )
             ## EMBEDDING CALL
             response = openai.Embedding.create(input=input, model=model)
+        elif model in litellm.cohere_embedding_models:
+            cohere_key = (
+                api_key
+                or litellm.cohere_key
+                or get_secret("COHERE_API_KEY")
+                or get_secret("CO_API_KEY")
+                or litellm.api_key
+            )
+            response = cohere.embedding(
+                model=model,
+                input=input,
+                encoding=encoding,
+                api_key=cohere_key,
+                logging_obj=logging,
+                model_response=EmbeddingResponse()
+
+            )
         else:
             args = locals()
             raise ValueError(f"No valid embedding model args passed in - {args}")
diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py
index cde73afc9..eea95e395 100644
--- a/litellm/tests/test_embedding.py
+++ b/litellm/tests/test_embedding.py
@@ -8,18 +8,28 @@ sys.path.insert(
 import litellm
 from litellm import embedding, completion
 
-litellm.set_verbose = True
-
+litellm.set_verbose = False
 
 def test_openai_embedding():
     try:
         response = embedding(
-            model="text-embedding-ada-002", input=["good morning from litellm"]
+            model="text-embedding-ada-002", input=["good morning from litellm", "this is another item"]
         )
+        print(response)
         # Add any assertions here to check the response
         # print(f"response: {str(response)}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-
-
 # test_openai_embedding()
+
+def test_cohere_embedding():
+    try:
+        response = embedding(
+            model="embed-english-v2.0", input=["good morning from litellm", "this is another item"]
+        )
+        print(f"response:", response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+# test_cohere_embedding()
+
diff --git a/litellm/utils.py b/litellm/utils.py
index 6ba4cbe3c..94ede879c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -173,6 +173,19 @@ class ModelResponse(OpenAIObject):
         d["choices"] = [choice.to_dict_recursive() for choice in self.choices]
         return d
 
+class EmbeddingResponse(OpenAIObject):
+    def __init__(self, id=None, choices=None, created=None, model=None, usage=None, stream=False, response_ms=None, **params):
+        self.object = "list"
+        if response_ms:
+            self._response_ms = response_ms
+        else:
+            self._response_ms = None
+        self.data = []
+        self.model = model
+
+    def to_dict_recursive(self):
+        d = super().to_dict_recursive()
+        return d
 
 ############################################################
 def print_verbose(print_statement):
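Reviewer note: a minimal sketch of how the new code path could be exercised once this patch is applied. This is not part of the patch itself; it assumes COHERE_API_KEY is exported in the environment (the elif branch added to main.py also falls back to CO_API_KEY or an explicit api_key argument), and the model name comes from the cohere_embedding_models list added in litellm/__init__.py.

    from litellm import embedding

    # Routes to litellm.llms.cohere.embedding() via the new elif branch in main.py.
    response = embedding(
        model="embed-english-v2.0",  # one of litellm.cohere_embedding_models
        input=["good morning from litellm"],
    )

    # The patch shapes the result like OpenAI's embedding payload: an
    # EmbeddingResponse with object="list" and one dict per input text.
    for item in response["data"]:
        print(item["index"], len(item["embedding"]))

The dict-style access on the response mirrors what the patch itself does when it assigns model_response["data"]; EmbeddingResponse inherits that behavior from OpenAIObject.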