mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
add cohere embedding models
This commit is contained in:
parent
8213e23970
commit
af914d4be1
5 changed files with 127 additions and 9 deletions
|
@ -286,6 +286,7 @@ longer_context_model_fallback_dict: dict = {
|
|||
|
||||
####### EMBEDDING MODELS ###################
|
||||
open_ai_embedding_models: List = ["text-embedding-ada-002"]
|
||||
cohere_embedding_models: List = ["embed-english-v2.0", "embed-english-light-v2.0", "embed-multilingual-v2.0"]
|
||||
|
||||
from .timeout import timeout
|
||||
from .testing import *
|
||||
|
|
|
@ -96,6 +96,74 @@ def completion(
|
|||
}
|
||||
return model_response
|
||||
|
||||
def embedding():
    """Placeholder embedding entry point (superseded by the full implementation below)."""
    # logic for parsing in - calling - parsing out model embedding calls
    pass
|
||||
def embedding(
    model: str,
    input: list,
    api_key: str,
    logging_obj=None,
    model_response=None,
    encoding=None,
):
    """Call the Cohere /v1/embed endpoint and map the result onto an
    OpenAI-style embedding response object.

    Args:
        model: Cohere embedding model name (e.g. "embed-english-v2.0").
        input: list of strings to embed.
        api_key: Cohere API key; turned into request headers by
            validate_environment.
        logging_obj: litellm logging object; its pre_call/post_call hooks
            are invoked around the HTTP request.
        model_response: response container to populate (EmbeddingResponse).
        encoding: tokenizer used to count prompt tokens locally, since the
            Cohere embed API does not report usage.

    Returns:
        model_response populated with "object", "data", "model" and "usage".

    Raises:
        Exception: when the API response carries no "embeddings" key
            (previously this surfaced as an opaque KeyError).
    """
    headers = validate_environment(api_key)
    embed_url = "https://api.cohere.ai/v1/embed"
    data = {
        "model": model,
        "texts": input,
    }

    ## LOGGING
    logging_obj.pre_call(
        input=input,
        api_key=api_key,
        additional_args={"complete_input_dict": data},
    )
    ## COMPLETION CALL
    response = requests.post(
        embed_url, headers=headers, data=json.dumps(data)
    )
    ## LOGGING
    logging_obj.post_call(
        input=input,
        api_key=api_key,
        additional_args={"complete_input_dict": data},
        original_response=response,
    )
    response_json = response.json()
    if "embeddings" not in response_json:
        # surface the API's error body instead of a bare KeyError
        raise Exception(f"Cohere embedding call failed: {response_json}")
    embeddings = response_json["embeddings"]
    # OpenAI-style per-item embedding records
    output_data = [
        {"object": "embedding", "index": idx, "embedding": emb}
        for idx, emb in enumerate(embeddings)
    ]
    model_response["object"] = "list"
    model_response["data"] = output_data
    model_response["model"] = model
    # Cohere returns no usage block; count prompt tokens locally
    input_tokens = sum(len(encoding.encode(text)) for text in input)
    model_response["usage"] = {
        "prompt_tokens": input_tokens,
        "total_tokens": input_tokens,
    }
    return model_response
|
||||
|
||||
|
||||
|
|
@ -54,6 +54,7 @@ from litellm.utils import (
|
|||
get_secret,
|
||||
CustomStreamWrapper,
|
||||
ModelResponse,
|
||||
EmbeddingResponse,
|
||||
read_config_args,
|
||||
)
|
||||
|
||||
|
@ -1352,7 +1353,15 @@ def batch_completion_models_all_responses(*args, **kwargs):
|
|||
60
|
||||
) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
|
||||
def embedding(
|
||||
model, input=[], azure=False, force_timeout=60, litellm_call_id=None, litellm_logging_obj=None, logger_fn=None, caching=False,
|
||||
model,
|
||||
input=[],
|
||||
azure=False,
|
||||
force_timeout=60,
|
||||
litellm_call_id=None,
|
||||
litellm_logging_obj=None,
|
||||
logger_fn=None,
|
||||
caching=False,
|
||||
api_key=None,
|
||||
):
|
||||
try:
|
||||
response = None
|
||||
|
@ -1393,6 +1402,23 @@ def embedding(
|
|||
)
|
||||
## EMBEDDING CALL
|
||||
response = openai.Embedding.create(input=input, model=model)
|
||||
elif model in litellm.cohere_embedding_models:
|
||||
cohere_key = (
|
||||
api_key
|
||||
or litellm.cohere_key
|
||||
or get_secret("COHERE_API_KEY")
|
||||
or get_secret("CO_API_KEY")
|
||||
or litellm.api_key
|
||||
)
|
||||
response = cohere.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
encoding=encoding,
|
||||
api_key=cohere_key,
|
||||
logging_obj=logging,
|
||||
model_response= EmbeddingResponse()
|
||||
|
||||
)
|
||||
else:
|
||||
args = locals()
|
||||
raise ValueError(f"No valid embedding model args passed in - {args}")
|
||||
|
|
|
@ -8,18 +8,28 @@ sys.path.insert(
|
|||
import litellm
|
||||
from litellm import embedding, completion
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
litellm.set_verbose = False
|
||||
|
||||
def test_openai_embedding():
    """Smoke test for the OpenAI embedding path; any raised exception fails the test."""
    texts = ["good morning from litellm", "this is another item"]
    try:
        response = embedding(model="text-embedding-ada-002", input=texts)
        print(response)
        # Add any assertions here to check the response
        # print(f"response: {str(response)}")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_openai_embedding()
|
||||
|
||||
def test_cohere_embedding():
    """Smoke test for the Cohere embedding path; any raised exception fails the test."""
    try:
        response = embedding(
            model="embed-english-v2.0",
            input=["good morning from litellm", "this is another item"],
        )
        # fixed: the literal has no placeholders, so the f-prefix was a no-op
        print("response:", response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_cohere_embedding()
|
||||
|
||||
|
|
|
@ -173,6 +173,19 @@ class ModelResponse(OpenAIObject):
|
|||
d["choices"] = [choice.to_dict_recursive() for choice in self.choices]
|
||||
return d
|
||||
|
||||
class EmbeddingResponse(OpenAIObject):
    """OpenAI-style container for embedding results.

    Always reports object="list"; data starts empty and is filled in by the
    provider handler. response_ms, when truthy, records the call latency.
    """

    def __init__(self, id=None, choices=None, created=None, model=None, usage=None, stream=False, response_ms=None, **params):
        self.object = "list"  # embedding responses are always "list" objects
        # keep the latency only when a truthy value was supplied
        self._response_ms = response_ms if response_ms else None
        self.data = []
        self.model = model

    def to_dict_recursive(self):
        # no embedding-specific serialization is needed beyond the base class
        return super().to_dict_recursive()
|
||||
|
||||
############################################################
|
||||
def print_verbose(print_statement):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue