fix(cohere.py): support async cohere embedding calls

Krrish Dholakia 2024-07-30 14:49:07 -07:00
parent 185a6857f9
commit 9b2eb1702b
3 changed files with 132 additions and 30 deletions

litellm/llms/cohere.py

@@ -4,12 +4,14 @@ import time
 import traceback
 import types
 from enum import Enum
-from typing import Callable, Optional
+from typing import Any, Callable, Optional, Union

 import httpx  # type: ignore
 import requests  # type: ignore

 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import Choices, Message, ModelResponse, Usage
@@ -246,14 +248,98 @@ def completion(
     return model_response


+def _process_embedding_response(
+    embeddings: list,
+    model_response: litellm.EmbeddingResponse,
+    model: str,
+    encoding: Any,
+    input: list,
+) -> litellm.EmbeddingResponse:
+    output_data = []
+    for idx, embedding in enumerate(embeddings):
+        output_data.append(
+            {"object": "embedding", "index": idx, "embedding": embedding}
+        )
+    model_response.object = "list"
+    model_response.data = output_data
+    model_response.model = model
+    input_tokens = 0
+    for text in input:
+        input_tokens += len(encoding.encode(text))
+    setattr(
+        model_response,
+        "usage",
+        Usage(
+            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+        ),
+    )
+    return model_response
+
+
+async def async_embedding(
+    model: str,
+    data: dict,
+    input: list,
+    model_response: litellm.utils.EmbeddingResponse,
+    timeout: Union[float, httpx.Timeout],
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    api_base: str,
+    api_key: Optional[str],
+    headers: dict,
+    encoding: Callable,
+    client: Optional[AsyncHTTPHandler] = None,
+):
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={
+            "complete_input_dict": data,
+            "headers": headers,
+            "api_base": api_base,
+        },
+    )
+    ## COMPLETION CALL
+    if client is None:
+        client = AsyncHTTPHandler(concurrent_limit=1)
+    response = await client.post(api_base, headers=headers, data=json.dumps(data))
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+        original_response=response,
+    )
+    embeddings = response.json()["embeddings"]
+    ## PROCESS RESPONSE ##
+    return _process_embedding_response(
+        embeddings=embeddings,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
+    )
+
+
 def embedding(
     model: str,
     input: list,
     model_response: litellm.EmbeddingResponse,
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    encoding: Any,
     api_key: Optional[str] = None,
-    logging_obj=None,
-    encoding=None,
-    optional_params=None,
+    aembedding: Optional[bool] = None,
+    timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     headers = validate_environment(api_key)
     embed_url = "https://api.cohere.ai/v1/embed"
@@ -270,8 +356,26 @@ def embedding(
         api_key=api_key,
         additional_args={"complete_input_dict": data},
     )
+
+    ## ROUTING
+    if aembedding is True:
+        return async_embedding(
+            model=model,
+            data=data,
+            input=input,
+            model_response=model_response,
+            timeout=timeout,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_base=embed_url,
+            api_key=api_key,
+            headers=headers,
+            encoding=encoding,
+        )
+
     ## COMPLETION CALL
-    response = requests.post(embed_url, headers=headers, data=json.dumps(data))
+    if client is None or not isinstance(client, HTTPHandler):
+        client = HTTPHandler(concurrent_limit=1)
+    response = client.post(embed_url, headers=headers, data=json.dumps(data))
     ## LOGGING
     logging_obj.post_call(
         input=input,
@@ -293,23 +397,11 @@ def embedding(
     if response.status_code != 200:
         raise CohereError(message=response.text, status_code=response.status_code)
     embeddings = response.json()["embeddings"]
-    output_data = []
-    for idx, embedding in enumerate(embeddings):
-        output_data.append(
-            {"object": "embedding", "index": idx, "embedding": embedding}
-        )
-    model_response.object = "list"
-    model_response.data = output_data
-    model_response.model = model
-    input_tokens = 0
-    for text in input:
-        input_tokens += len(encoding.encode(text))
-    setattr(
-        model_response,
-        "usage",
-        Usage(
-            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-        ),
-    )
-    return model_response
+
+    return _process_embedding_response(
+        embeddings=embeddings,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
+    )
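
For context, a minimal sketch of how the new async path is exercised end to end, mirroring the test added below (assumes COHERE_API_KEY is set in the environment; the aembedding=True flag is threaded through by litellm's router in main.py, not passed by the caller):

import asyncio
import litellm

async def main():
    # Routes through cohere.async_embedding() via the new aembedding flag.
    response = await litellm.aembedding(
        model="embed-english-v2.0",  # cohere model, same as the test below
        input=["good morning from litellm", "this is another item"],
        input_type="search_query",
    )
    print(response.usage)  # prompt tokens only; completion_tokens is 0

asyncio.run(main())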

litellm/main.py

@@ -3114,6 +3114,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
         or custom_llm_provider == "vertex_ai"
         or custom_llm_provider == "databricks"
         or custom_llm_provider == "watsonx"
+        or custom_llm_provider == "cohere"
     ):  # currently implemented aiohttp calls for just azure and openai, soon all.
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
@@ -3440,9 +3441,12 @@ def embedding(
             input=input,
             optional_params=optional_params,
             encoding=encoding,
-            api_key=cohere_key,
+            api_key=cohere_key,  # type: ignore
             logging_obj=logging,
             model_response=EmbeddingResponse(),
+            aembedding=aembedding,
+            timeout=float(timeout),
+            client=client,
         )
     elif custom_llm_provider == "huggingface":
         api_key = (
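
The two hunks above show the router side of the change: embedding() in main.py now forwards aembedding, timeout, and client to the cohere handler, and aembedding() treats "cohere" like the other providers whose synchronous entrypoint returns a coroutine to be awaited. A minimal standalone sketch of that dispatch pattern (simplified names, not the actual litellm code):

import asyncio
from typing import Optional

async def _async_call(input: list) -> dict:
    await asyncio.sleep(0)  # stand-in for the async HTTP request
    return {"data": input}

def embedding(input: list, aembedding: Optional[bool] = None):
    # With the flag set, hand back a coroutine instead of a result;
    # the async wrapper below is responsible for awaiting it.
    if aembedding is True:
        return _async_call(input)
    return {"data": input}  # stand-in for the sync HTTP request

async def aembedding(input: list) -> dict:
    init_response = embedding(input, aembedding=True)
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response

print(asyncio.run(aembedding(["hello"])))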

litellm/tests/test_embedding.py

@@ -257,14 +257,20 @@ def test_openai_azure_embedding_optional_arg(mocker):
 # test_openai_embedding()


-def test_cohere_embedding():
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_cohere_embedding(sync_mode):
     try:
         # litellm.set_verbose=True
-        response = embedding(
-            model="embed-english-v2.0",
-            input=["good morning from litellm", "this is another item"],
-            input_type="search_query",
-        )
+        data = {
+            "model": "embed-english-v2.0",
+            "input": ["good morning from litellm", "this is another item"],
+            "input_type": "search_query",
+        }
+        if sync_mode:
+            response = embedding(**data)
+        else:
+            response = await litellm.aembedding(**data)
         print(f"response:", response)

         assert isinstance(response.usage, litellm.Usage)
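
The final assertion ties back to _process_embedding_response: for embeddings, usage is computed purely by tokenizing the input, with zero completion tokens. A standalone sketch of that accounting (tiktoken's cl100k_base is used here as a stand-in tokenizer; the real code uses the encoding object litellm passes in):

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")  # stand-in for litellm's encoding
input_texts = ["good morning from litellm", "this is another item"]

# Mirrors the token loop in _process_embedding_response: prompt tokens only.
input_tokens = sum(len(encoding.encode(text)) for text in input_texts)
usage = dict(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens)
print(usage)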