fix(cohere.py): support async cohere embedding calls
parent 185a6857f9
commit 9b2eb1702b
3 changed files with 132 additions and 30 deletions
@@ -4,12 +4,14 @@ import time
 import traceback
 import types
 from enum import Enum
-from typing import Callable, Optional
+from typing import Any, Callable, Optional, Union

 import httpx  # type: ignore
 import requests  # type: ignore

 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import Choices, Message, ModelResponse, Usage
@@ -246,14 +248,98 @@ def completion(
     return model_response


+def _process_embedding_response(
+    embeddings: list,
+    model_response: litellm.EmbeddingResponse,
+    model: str,
+    encoding: Any,
+    input: list,
+) -> litellm.EmbeddingResponse:
+    output_data = []
+    for idx, embedding in enumerate(embeddings):
+        output_data.append(
+            {"object": "embedding", "index": idx, "embedding": embedding}
+        )
+    model_response.object = "list"
+    model_response.data = output_data
+    model_response.model = model
+    input_tokens = 0
+    for text in input:
+        input_tokens += len(encoding.encode(text))
+
+    setattr(
+        model_response,
+        "usage",
+        Usage(
+            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+        ),
+    )
+
+    return model_response
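The helper above centralizes the mapping from Cohere's raw embedding list onto the OpenAI-style response shape, so the sync and async paths can share it. As a rough illustration of what it returns — assuming the helper lives in litellm.llms.cohere as the commit title suggests, and using a hypothetical stand-in encoder, since the helper only ever calls len(encoding.encode(text)):

import litellm
from litellm.llms.cohere import _process_embedding_response  # assumed module path

class FakeEncoding:
    # Hypothetical tokenizer stand-in; litellm normally passes a real encoding here.
    def encode(self, text: str) -> list:
        return text.split()  # crude whitespace "tokens", for illustration only

resp = _process_embedding_response(
    embeddings=[[0.1, 0.2], [0.3, 0.4]],
    model_response=litellm.EmbeddingResponse(),
    model="embed-english-v2.0",
    encoding=FakeEncoding(),
    input=["good morning from litellm", "this is another item"],
)
# resp.object  -> "list"
# resp.data[0] -> {"object": "embedding", "index": 0, "embedding": [0.1, 0.2]}
# resp.usage   -> Usage(prompt_tokens=8, completion_tokens=0, total_tokens=8)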
+
+
+async def async_embedding(
+    model: str,
+    data: dict,
+    input: list,
+    model_response: litellm.utils.EmbeddingResponse,
+    timeout: Union[float, httpx.Timeout],
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    api_base: str,
+    api_key: Optional[str],
+    headers: dict,
+    encoding: Callable,
+    client: Optional[AsyncHTTPHandler] = None,
+):
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={
+            "complete_input_dict": data,
+            "headers": headers,
+            "api_base": api_base,
+        },
+    )
+    ## COMPLETION CALL
+    if client is None:
+        client = AsyncHTTPHandler(concurrent_limit=1)
+
+    response = await client.post(api_base, headers=headers, data=json.dumps(data))
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+        original_response=response,
+    )
+
+    embeddings = response.json()["embeddings"]
+
+    ## PROCESS RESPONSE ##
+    return _process_embedding_response(
+        embeddings=embeddings,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
+    )
+
+
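One detail worth flagging in async_embedding: when no client is supplied, it constructs a fresh AsyncHTTPHandler(concurrent_limit=1) on every call. A caller fanning out many embedding requests can amortize connection setup by sharing one handler. A minimal sketch — it assumes the top-level entrypoints forward a client kwarg down to this provider call, which the main.py hunk further below suggests they do for Cohere:

import asyncio
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

async def main():
    shared_client = AsyncHTTPHandler(concurrent_limit=10)  # reused across all calls
    responses = await asyncio.gather(
        *[
            litellm.aembedding(
                model="embed-english-v2.0",
                input=["good morning from litellm"],
                client=shared_client,  # assumes kwarg passthrough, see main.py hunk
            )
            for _ in range(3)
        ]
    )
    print(responses[0].usage)

asyncio.run(main())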
 def embedding(
     model: str,
     input: list,
     model_response: litellm.EmbeddingResponse,
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    encoding: Any,
     api_key: Optional[str] = None,
-    logging_obj=None,
-    encoding=None,
-    optional_params=None,
+    aembedding: Optional[bool] = None,
+    timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     headers = validate_environment(api_key)
     embed_url = "https://api.cohere.ai/v1/embed"
@@ -270,8 +356,26 @@ def embedding(
         api_key=api_key,
         additional_args={"complete_input_dict": data},
     )
+
+    ## ROUTING
+    if aembedding is True:
+        return async_embedding(
+            model=model,
+            data=data,
+            input=input,
+            model_response=model_response,
+            timeout=timeout,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_base=embed_url,
+            api_key=api_key,
+            headers=headers,
+            encoding=encoding,
+        )
     ## COMPLETION CALL
-    response = requests.post(embed_url, headers=headers, data=json.dumps(data))
+    if client is None or not isinstance(client, HTTPHandler):
+        client = HTTPHandler(concurrent_limit=1)
+    response = client.post(embed_url, headers=headers, data=json.dumps(data))
     ## LOGGING
     logging_obj.post_call(
         input=input,
@@ -293,23 +397,11 @@ def embedding(
     if response.status_code != 200:
         raise CohereError(message=response.text, status_code=response.status_code)
     embeddings = response.json()["embeddings"]
-    output_data = []
-    for idx, embedding in enumerate(embeddings):
-        output_data.append(
-            {"object": "embedding", "index": idx, "embedding": embedding}
-        )
-    model_response.object = "list"
-    model_response.data = output_data
-    model_response.model = model
-    input_tokens = 0
-    for text in input:
-        input_tokens += len(encoding.encode(text))
-
-    setattr(
-        model_response,
-        "usage",
-        Usage(
-            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-        ),
+
+    return _process_embedding_response(
+        embeddings=embeddings,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
+    )
     )
-    return model_response
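With that, the synchronous path in cohere.py also drops requests.post in favor of an injectable HTTPHandler, so blocking callers can reuse a connection too. A hedged sketch of the sync equivalent (again assuming the client kwarg is forwarded from the top level):

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler(concurrent_limit=1)  # reused instead of a fresh handler per call
response = litellm.embedding(
    model="embed-english-v2.0",
    input=["good morning from litellm", "this is another item"],
    input_type="search_query",
    client=client,
)
print(response.usage)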
@@ -3114,6 +3114,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
         or custom_llm_provider == "vertex_ai"
         or custom_llm_provider == "databricks"
         or custom_llm_provider == "watsonx"
+        or custom_llm_provider == "cohere"
     ):  # currently implemented aiohttp calls for just azure and openai, soon all.
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
@@ -3440,9 +3441,12 @@ def embedding(
             input=input,
             optional_params=optional_params,
             encoding=encoding,
-            api_key=cohere_key,
+            api_key=cohere_key,  # type: ignore
             logging_obj=logging,
             model_response=EmbeddingResponse(),
+            aembedding=aembedding,
+            timeout=float(timeout),
+            client=client,
         )
     elif custom_llm_provider == "huggingface":
         api_key = (
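Together, the two hunks above (from litellm's top-level module) add "cohere" to the providers whose embedding coroutine is awaited, and thread aembedding, timeout, and client down into the Cohere provider call. End to end, an async Cohere embedding now looks like this sketch, which mirrors the updated test below and assumes COHERE_API_KEY is set in the environment:

import asyncio
import litellm

async def main():
    response = await litellm.aembedding(
        model="embed-english-v2.0",
        input=["good morning from litellm", "this is another item"],
        input_type="search_query",
    )
    print(response.usage)

asyncio.run(main())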
@@ -257,14 +257,20 @@ def test_openai_azure_embedding_optional_arg(mocker):
 # test_openai_embedding()


-def test_cohere_embedding():
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_cohere_embedding(sync_mode):
     try:
         # litellm.set_verbose=True
-        response = embedding(
-            model="embed-english-v2.0",
-            input=["good morning from litellm", "this is another item"],
-            input_type="search_query",
-        )
+        data = {
+            "model": "embed-english-v2.0",
+            "input": ["good morning from litellm", "this is another item"],
+            "input_type": "search_query",
+        }
+        if sync_mode:
+            response = embedding(**data)
+        else:
+            response = await litellm.aembedding(**data)
         print(f"response:", response)

         assert isinstance(response.usage, litellm.Usage)
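The parametrized test now covers the sync and async paths in one run. To execute just this test, something like the following should work — the file name is assumed here, and the async variant needs the pytest-asyncio plugin installed for @pytest.mark.asyncio to take effect:

pytest test_embedding.py -k test_cohere_embedding -s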