fix(cohere.py): support async cohere embedding calls

Krrish Dholakia 2024-07-30 14:49:07 -07:00
parent 185a6857f9
commit 9b2eb1702b
3 changed files with 132 additions and 30 deletions
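
In practice, this adds an awaitable path for Cohere embeddings through litellm.aembedding. A minimal usage sketch, mirroring the parametrized test below (assumes a valid COHERE_API_KEY in the environment):

    import asyncio
    import litellm

    async def main():
        # routes through the new async_embedding() added in cohere.py
        response = await litellm.aembedding(
            model="embed-english-v2.0",
            input=["good morning from litellm", "this is another item"],
            input_type="search_query",
        )
        print(response.usage)

    asyncio.run(main())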

litellm/llms/cohere.py

@@ -4,12 +4,14 @@ import time
 import traceback
 import types
 from enum import Enum
-from typing import Callable, Optional
+from typing import Any, Callable, Optional, Union

 import httpx  # type: ignore
 import requests  # type: ignore

 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import Choices, Message, ModelResponse, Usage
@@ -246,14 +248,98 @@ def completion(
     return model_response


+def _process_embedding_response(
+    embeddings: list,
+    model_response: litellm.EmbeddingResponse,
+    model: str,
+    encoding: Any,
+    input: list,
+) -> litellm.EmbeddingResponse:
+    output_data = []
+    for idx, embedding in enumerate(embeddings):
+        output_data.append(
+            {"object": "embedding", "index": idx, "embedding": embedding}
+        )
+    model_response.object = "list"
+    model_response.data = output_data
+    model_response.model = model
+    input_tokens = 0
+    for text in input:
+        input_tokens += len(encoding.encode(text))
+    setattr(
+        model_response,
+        "usage",
+        Usage(
+            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+        ),
+    )
+    return model_response
+
+
+async def async_embedding(
+    model: str,
+    data: dict,
+    input: list,
+    model_response: litellm.utils.EmbeddingResponse,
+    timeout: Union[float, httpx.Timeout],
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    api_base: str,
+    api_key: Optional[str],
+    headers: dict,
+    encoding: Callable,
+    client: Optional[AsyncHTTPHandler] = None,
+):
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={
+            "complete_input_dict": data,
+            "headers": headers,
+            "api_base": api_base,
+        },
+    )
+    ## COMPLETION CALL
+    if client is None:
+        client = AsyncHTTPHandler(concurrent_limit=1)
+
+    response = await client.post(api_base, headers=headers, data=json.dumps(data))
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+        original_response=response,
+    )
+
+    embeddings = response.json()["embeddings"]
+
+    ## PROCESS RESPONSE ##
+    return _process_embedding_response(
+        embeddings=embeddings,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
+    )
+
+
 def embedding(
     model: str,
     input: list,
     model_response: litellm.EmbeddingResponse,
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    encoding: Any,
     api_key: Optional[str] = None,
-    logging_obj=None,
-    encoding=None,
-    optional_params=None,
+    aembedding: Optional[bool] = None,
+    timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     headers = validate_environment(api_key)
     embed_url = "https://api.cohere.ai/v1/embed"
@@ -270,8 +356,26 @@ def embedding(
         api_key=api_key,
         additional_args={"complete_input_dict": data},
     )
+    ## ROUTING
+    if aembedding is True:
+        return async_embedding(
+            model=model,
+            data=data,
+            input=input,
+            model_response=model_response,
+            timeout=timeout,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_base=embed_url,
+            api_key=api_key,
+            headers=headers,
+            encoding=encoding,
+        )
     ## COMPLETION CALL
-    response = requests.post(embed_url, headers=headers, data=json.dumps(data))
+    if client is None or not isinstance(client, HTTPHandler):
+        client = HTTPHandler(concurrent_limit=1)
+    response = client.post(embed_url, headers=headers, data=json.dumps(data))
     ## LOGGING
     logging_obj.post_call(
         input=input,
@@ -293,23 +397,11 @@ def embedding(
     if response.status_code != 200:
         raise CohereError(message=response.text, status_code=response.status_code)
     embeddings = response.json()["embeddings"]
-    output_data = []
-    for idx, embedding in enumerate(embeddings):
-        output_data.append(
-            {"object": "embedding", "index": idx, "embedding": embedding}
-        )
-    model_response.object = "list"
-    model_response.data = output_data
-    model_response.model = model
-    input_tokens = 0
-    for text in input:
-        input_tokens += len(encoding.encode(text))
-    setattr(
-        model_response,
-        "usage",
-        Usage(
-            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-        ),
-    )
-    return model_response
+    return _process_embedding_response(
+        embeddings=embeddings,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
+    )
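
Both the sync and async paths now converge on _process_embedding_response, which derives usage from the encoder rather than from the provider response. A standalone sketch of that accounting (using tiktoken as the encoder is an assumption for illustration; litellm injects its own encoding object):

    import tiktoken

    encoding = tiktoken.get_encoding("cl100k_base")  # assumed stand-in for litellm's encoding
    input = ["good morning from litellm", "this is another item"]
    input_tokens = sum(len(encoding.encode(text)) for text in input)
    # mirrors the helper: prompt_tokens == total_tokens, completion_tokens == 0
    print(input_tokens)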

litellm/main.py

@@ -3114,6 +3114,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
         or custom_llm_provider == "vertex_ai"
         or custom_llm_provider == "databricks"
         or custom_llm_provider == "watsonx"
+        or custom_llm_provider == "cohere"
     ):  # currently implemented aiohttp calls for just azure and openai, soon all.
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
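
The dispatch above works because the sync embedding() in cohere.py returns the un-awaited coroutine from async_embedding when aembedding is True; aembedding runs the sync entrypoint in an executor and then awaits whatever comes back. A generic, self-contained sketch of that pattern (all names here are illustrative, not litellm's):

    import asyncio

    async def _work_async():
        return "async result"

    def work(use_async: bool = False):
        # like cohere.py's embedding(): hand back the coroutine un-awaited
        if use_async:
            return _work_async()
        return "sync result"

    async def awork():
        # like aembedding(): run the sync entrypoint in an executor,
        # then await the coroutine it returned
        loop = asyncio.get_running_loop()
        init_response = await loop.run_in_executor(None, lambda: work(use_async=True))
        if asyncio.iscoroutine(init_response):
            return await init_response
        return init_response

    print(asyncio.run(awork()))  # -> "async result"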
@@ -3440,9 +3441,12 @@ def embedding(
             input=input,
             optional_params=optional_params,
             encoding=encoding,
-            api_key=cohere_key,
+            api_key=cohere_key,  # type: ignore
             logging_obj=logging,
             model_response=EmbeddingResponse(),
+            aembedding=aembedding,
+            timeout=float(timeout),
+            client=client,
         )
     elif custom_llm_provider == "huggingface":
         api_key = (

litellm/tests/test_embedding.py

@@ -257,14 +257,20 @@ def test_openai_azure_embedding_optional_arg(mocker):
 # test_openai_embedding()


-def test_cohere_embedding():
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_cohere_embedding(sync_mode):
     try:
         # litellm.set_verbose=True
-        response = embedding(
-            model="embed-english-v2.0",
-            input=["good morning from litellm", "this is another item"],
-            input_type="search_query",
-        )
+        data = {
+            "model": "embed-english-v2.0",
+            "input": ["good morning from litellm", "this is another item"],
+            "input_type": "search_query",
+        }
+        if sync_mode:
+            response = embedding(**data)
+        else:
+            response = await litellm.aembedding(**data)
         print(f"response:", response)

         assert isinstance(response.usage, litellm.Usage)
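
To exercise both parametrized paths locally, an invocation along these lines should work (assumes pytest-asyncio is installed, which @pytest.mark.asyncio requires, plus a COHERE_API_KEY in the environment):

    pytest test_embedding.py -k test_cohere_embedding -s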