fix(cohere.py): support async cohere embedding calls

commit 9b2eb1702b
parent 185a6857f9
3 changed files with 132 additions and 30 deletions

@@ -4,12 +4,14 @@ import time
 import traceback
 import types
 from enum import Enum
-from typing import Callable, Optional
+from typing import Any, Callable, Optional, Union

 import httpx  # type: ignore
 import requests  # type: ignore

 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import Choices, Message, ModelResponse, Usage


@@ -246,14 +248,98 @@ def completion(
     return model_response


+def _process_embedding_response(
+    embeddings: list,
+    model_response: litellm.EmbeddingResponse,
+    model: str,
+    encoding: Any,
+    input: list,
+) -> litellm.EmbeddingResponse:
+    output_data = []
+    for idx, embedding in enumerate(embeddings):
+        output_data.append(
+            {"object": "embedding", "index": idx, "embedding": embedding}
+        )
+    model_response.object = "list"
+    model_response.data = output_data
+    model_response.model = model
+    input_tokens = 0
+    for text in input:
+        input_tokens += len(encoding.encode(text))
+
+    setattr(
+        model_response,
+        "usage",
+        Usage(
+            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+        ),
+    )
+
+    return model_response
+
+
+async def async_embedding(
+    model: str,
+    data: dict,
+    input: list,
+    model_response: litellm.utils.EmbeddingResponse,
+    timeout: Union[float, httpx.Timeout],
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    api_base: str,
+    api_key: Optional[str],
+    headers: dict,
+    encoding: Callable,
+    client: Optional[AsyncHTTPHandler] = None,
+):
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={
+            "complete_input_dict": data,
+            "headers": headers,
+            "api_base": api_base,
+        },
+    )
+    ## COMPLETION CALL
+    if client is None:
+        client = AsyncHTTPHandler(concurrent_limit=1)
+
+    response = await client.post(api_base, headers=headers, data=json.dumps(data))
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+        original_response=response,
+    )
+
+    embeddings = response.json()["embeddings"]
+
+    ## PROCESS RESPONSE ##
+    return _process_embedding_response(
+        embeddings=embeddings,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
+    )
+
+
 def embedding(
     model: str,
     input: list,
     model_response: litellm.EmbeddingResponse,
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    encoding: Any,
     api_key: Optional[str] = None,
-    logging_obj=None,
-    encoding=None,
-    optional_params=None,
+    aembedding: Optional[bool] = None,
+    timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     headers = validate_environment(api_key)
     embed_url = "https://api.cohere.ai/v1/embed"
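
The `_process_embedding_response` helper added above only reshapes Cohere's raw embeddings list into the OpenAI-style response object and counts prompt tokens with the supplied encoding. A standalone sketch of the same transformation, using a plain dict in place of litellm's EmbeddingResponse and assuming a tiktoken tokenizer for the `encoding` argument:

    import tiktoken

    def process_embedding_response_sketch(embeddings: list, model: str, texts: list) -> dict:
        # Reshape the raw vectors into OpenAI-style embedding objects, as the helper above does.
        data = [
            {"object": "embedding", "index": idx, "embedding": emb}
            for idx, emb in enumerate(embeddings)
        ]
        # Usage mirrors the helper: prompt tokens only, summed over the input texts.
        encoding = tiktoken.get_encoding("cl100k_base")  # assumption: any object with .encode() works
        input_tokens = sum(len(encoding.encode(text)) for text in texts)
        return {
            "object": "list",
            "model": model,
            "data": data,
            "usage": {"prompt_tokens": input_tokens, "completion_tokens": 0, "total_tokens": input_tokens},
        }

    # Example with two fake 3-dimensional vectors:
    print(process_embedding_response_sketch([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], "embed-english-v2.0", ["hello", "world"]))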

@@ -270,8 +356,26 @@ def embedding(
         api_key=api_key,
         additional_args={"complete_input_dict": data},
     )
+
+    ## ROUTING
+    if aembedding is True:
+        return async_embedding(
+            model=model,
+            data=data,
+            input=input,
+            model_response=model_response,
+            timeout=timeout,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_base=embed_url,
+            api_key=api_key,
+            headers=headers,
+            encoding=encoding,
+        )
     ## COMPLETION CALL
-    response = requests.post(embed_url, headers=headers, data=json.dumps(data))
+    if client is None or not isinstance(client, HTTPHandler):
+        client = HTTPHandler(concurrent_limit=1)
+    response = client.post(embed_url, headers=headers, data=json.dumps(data))
     ## LOGGING
     logging_obj.post_call(
         input=input,
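
Note that `embedding()` stays a synchronous def: when `aembedding is True` it returns the coroutine produced by `async_embedding()` without awaiting it, and the synchronous HTTPHandler call never runs; the async caller is expected to await the returned coroutine. A minimal sketch of that routing pattern, with hypothetical stand-in functions:

    import asyncio

    async def _fake_async_call(data: dict) -> dict:
        # Stand-in for async_embedding(); returns a canned payload.
        return {"embeddings": [[0.0, 1.0]]}

    def _fake_sync_call(data: dict) -> dict:
        # Stand-in for the HTTPHandler-based synchronous path.
        return {"embeddings": [[0.0, 1.0]]}

    def embed(data: dict, aembedding: bool = False):
        if aembedding is True:
            # Not awaited here: the coroutine is handed back to the caller.
            return _fake_async_call(data)
        return _fake_sync_call(data)

    async def main():
        result = embed({"texts": ["hi"]}, aembedding=True)
        return await result  # the caller resolves the coroutine

    print(asyncio.run(main()))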

@@ -293,23 +397,11 @@ def embedding(
     if response.status_code != 200:
         raise CohereError(message=response.text, status_code=response.status_code)
     embeddings = response.json()["embeddings"]
-    output_data = []
-    for idx, embedding in enumerate(embeddings):
-        output_data.append(
-            {"object": "embedding", "index": idx, "embedding": embedding}
-        )
-    model_response.object = "list"
-    model_response.data = output_data
-    model_response.model = model
-    input_tokens = 0
-    for text in input:
-        input_tokens += len(encoding.encode(text))

-    setattr(
-        model_response,
-        "usage",
-        Usage(
-            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-        ),
+    return _process_embedding_response(
+        embeddings=embeddings,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
     )
-    return model_response

@@ -3114,6 +3114,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
             or custom_llm_provider == "vertex_ai"
             or custom_llm_provider == "databricks"
             or custom_llm_provider == "watsonx"
+            or custom_llm_provider == "cohere"
         ):  # currently implemented aiohttp calls for just azure and openai, soon all.
             # Await normally
             init_response = await loop.run_in_executor(None, func_with_context)
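
Adding "cohere" to this provider list is what lets the coroutine returned by the routing above be awaited: `aembedding` runs the synchronous `embedding()` entrypoint in an executor, and if that call hands back a coroutine (as the cohere path now does when `aembedding=True`), it awaits it. A simplified sketch of the pattern, not the exact litellm code:

    import asyncio

    async def aembedding_sketch(func_with_context):
        loop = asyncio.get_event_loop()
        # The sync entrypoint runs in a worker thread; providers on the list above
        # may return a coroutine instead of a finished response.
        init_response = await loop.run_in_executor(None, func_with_context)
        if asyncio.iscoroutine(init_response):
            return await init_response
        return init_response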

@@ -3440,9 +3441,12 @@ def embedding(
             input=input,
             optional_params=optional_params,
             encoding=encoding,
-            api_key=cohere_key,
+            api_key=cohere_key,  # type: ignore
             logging_obj=logging,
             model_response=EmbeddingResponse(),
+            aembedding=aembedding,
+            timeout=float(timeout),
+            client=client,
         )
     elif custom_llm_provider == "huggingface":
         api_key = (

@@ -257,14 +257,20 @@ def test_openai_azure_embedding_optional_arg(mocker):
 # test_openai_embedding()


-def test_cohere_embedding():
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_cohere_embedding(sync_mode):
     try:
         # litellm.set_verbose=True
-        response = embedding(
-            model="embed-english-v2.0",
-            input=["good morning from litellm", "this is another item"],
-            input_type="search_query",
-        )
+        data = {
+            "model": "embed-english-v2.0",
+            "input": ["good morning from litellm", "this is another item"],
+            "input_type": "search_query",
+        }
+        if sync_mode:
+            response = embedding(**data)
+        else:
+            response = await litellm.aembedding(**data)
         print(f"response:", response)

         assert isinstance(response.usage, litellm.Usage)
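
Based on the parametrized test above, the new async path can be exercised directly, assuming a Cohere API key is configured in the environment (e.g. COHERE_API_KEY):

    import asyncio
    import litellm

    async def main():
        response = await litellm.aembedding(
            model="embed-english-v2.0",
            input=["good morning from litellm", "this is another item"],
            input_type="search_query",
        )
        print(response.usage)

    asyncio.run(main())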