add cohere embedding models

This commit is contained in:
ishaan-jaff 2023-09-29 09:59:31 -07:00
parent 8213e23970
commit af914d4be1
5 changed files with 127 additions and 9 deletions

View file

@@ -286,6 +286,7 @@ longer_context_model_fallback_dict: dict = {
####### EMBEDDING MODELS ###################
# Registries of known embedding model names, used by embedding() in main.py
# to route a request to the matching provider.
open_ai_embedding_models: List = ["text-embedding-ada-002"]
cohere_embedding_models: List = ["embed-english-v2.0", "embed-english-light-v2.0", "embed-multilingual-v2.0"]
from .timeout import timeout
from .testing import *

View file

@@ -96,6 +96,74 @@ def completion(
}
return model_response
def embedding():
    """Placeholder for embedding support.

    NOTE(review): dead code — the full `def embedding(...)` defined immediately
    below shadows this stub at import time; it can be deleted.
    """
    # logic for parsing in - calling - parsing out model embedding calls
    pass
def embedding(
    model: str,
    input: list,
    api_key: str,
    logging_obj=None,
    model_response=None,
    encoding=None,
):
    """Call the Cohere /v1/embed endpoint and map the result onto an
    OpenAI-style embedding response.

    Args:
        model: Cohere embedding model name (e.g. "embed-english-v2.0").
        input: list of strings to embed.
        api_key: Cohere API key used to build the request headers.
        logging_obj: litellm logging object providing pre_call/post_call hooks.
        model_response: dict-like response object populated in place and returned.
        encoding: tokenizer used to estimate prompt token counts for usage stats.

    Returns:
        model_response with "object", "data", "model" and "usage" filled in,
        where "data" mirrors OpenAI's [{"object", "index", "embedding"}, ...].

    Raises:
        Exception: if the API response contains no "embeddings" key
            (e.g. auth failure or bad request), instead of a bare KeyError.
    """
    headers = validate_environment(api_key)
    embed_url = "https://api.cohere.ai/v1/embed"
    data = {
        "model": model,
        "texts": input,
    }
    ## LOGGING
    logging_obj.pre_call(
        input=input,
        api_key=api_key,
        additional_args={"complete_input_dict": data},
    )
    ## COMPLETION CALL
    response = requests.post(
        embed_url, headers=headers, data=json.dumps(data)
    )
    ## LOGGING
    logging_obj.post_call(
        input=input,
        api_key=api_key,
        additional_args={"complete_input_dict": data},
        original_response=response,
    )
    response_json = response.json()
    # Surface API errors explicitly rather than failing with an opaque
    # KeyError when Cohere returns an error payload (no "embeddings" key).
    if "embeddings" not in response_json:
        raise Exception(f"Cohere embedding API error - {response_json}")
    output_data = [
        {"object": "embedding", "index": idx, "embedding": emb}
        for idx, emb in enumerate(response_json["embeddings"])
    ]
    model_response["object"] = "list"
    model_response["data"] = output_data
    model_response["model"] = model
    # Cohere does not report token usage; estimate it with the local tokenizer.
    input_tokens = 0
    for text in input:
        input_tokens += len(encoding.encode(text))
    model_response["usage"] = {
        "prompt_tokens": input_tokens,
        "total_tokens": input_tokens,
    }
    return model_response

View file

@@ -54,6 +54,7 @@ from litellm.utils import (
get_secret,
CustomStreamWrapper,
ModelResponse,
EmbeddingResponse,
read_config_args,
)
@@ -1352,7 +1353,15 @@ def batch_completion_models_all_responses(*args, **kwargs):
60
) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
def embedding(
model, input=[], azure=False, force_timeout=60, litellm_call_id=None, litellm_logging_obj=None, logger_fn=None, caching=False,
model,
input=[],
azure=False,
force_timeout=60,
litellm_call_id=None,
litellm_logging_obj=None,
logger_fn=None,
caching=False,
api_key=None,
):
try:
response = None
@@ -1393,6 +1402,23 @@ def embedding(
)
## EMBEDDING CALL
response = openai.Embedding.create(input=input, model=model)
elif model in litellm.cohere_embedding_models:
cohere_key = (
api_key
or litellm.cohere_key
or get_secret("COHERE_API_KEY")
or get_secret("CO_API_KEY")
or litellm.api_key
)
response = cohere.embedding(
model=model,
input=input,
encoding=encoding,
api_key=cohere_key,
logging_obj=logging,
model_response= EmbeddingResponse()
)
else:
args = locals()
raise ValueError(f"No valid embedding model args passed in - {args}")

View file

@@ -8,18 +8,28 @@ sys.path.insert(
import litellm
from litellm import embedding, completion
litellm.set_verbose = True
litellm.set_verbose = False
def test_openai_embedding():
    """Smoke test: an OpenAI embedding call for two inputs should not raise."""
    texts = ["good morning from litellm", "this is another item"]
    try:
        result = embedding(model="text-embedding-ada-002", input=texts)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
    else:
        print(result)
        # Add any assertions here to check the response
        # print(f"response: {str(response)}")
# test_openai_embedding()
def test_cohere_embedding():
    """Smoke test: a Cohere embedding call for two inputs should not raise."""
    try:
        response = embedding(
            model="embed-english-v2.0", input=["good morning from litellm", "this is another item"]
        )
        # Fix: the original `print(f"response:", response)` used a
        # placeholder-free f-string plus a second positional argument;
        # interpolate properly (printed output is unchanged).
        print(f"response: {response}")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test_cohere_embedding()

View file

@@ -173,6 +173,19 @@ class ModelResponse(OpenAIObject):
d["choices"] = [choice.to_dict_recursive() for choice in self.choices]
return d
class EmbeddingResponse(OpenAIObject):
    """OpenAI-style container for embedding results.

    Fields mirror OpenAI's embedding response: object="list", data (the
    per-input embedding dicts), model, and — set externally by callers —
    usage.
    """

    def __init__(self, id=None, choices=None, created=None, model=None, usage=None, stream=False, response_ms=None, **params):
        # NOTE(review): `id`, `choices`, `created`, `usage`, `stream` and
        # **params are accepted (presumably to mirror ModelResponse's
        # signature) but are currently ignored, and super().__init__() is
        # never called — confirm OpenAIObject tolerates that.
        self.object = "list"
        if response_ms:
            # Latency of the underlying API call, when the caller measured it.
            self._response_ms = response_ms
        else:
            self._response_ms = None
        self.data = []  # filled later with {"object", "index", "embedding"} dicts
        self.model = model

    def to_dict_recursive(self):
        # Delegates entirely to OpenAIObject's implementation.
        d = super().to_dict_recursive()
        return d
############################################################
def print_verbose(print_statement):