feat(cohere/embed): pass client to async embedding

This commit is contained in:
Krrish Dholakia 2024-11-23 00:47:26 +05:30
parent 1a3fb18a64
commit 94fe135524
2 changed files with 28 additions and 9 deletions

View file

@ -74,6 +74,7 @@ async def async_embedding(
},
)
## COMPLETION CALL
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.COHERE,
@ -151,6 +152,11 @@ def embedding(
api_key=api_key,
headers=headers,
encoding=encoding,
client=(
client
if client is not None and isinstance(client, AsyncHTTPHandler)
else None
),
)
## LOGGING

View file

@ -1082,18 +1082,31 @@ def test_cohere_img_embeddings(input, input_type):
assert response.usage.prompt_tokens_details.text_tokens > 0
def test_embedding_with_extra_headers():
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_embedding_with_extra_headers(sync_mode):
input = ["hello world"]
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
client = HTTPHandler()
if sync_mode:
client = HTTPHandler()
else:
client = AsyncHTTPHandler()
data = {
"model": "cohere/embed-english-v3.0",
"input": input,
"extra_headers": {"my-test-param": "hello-world"},
"client": client,
}
with patch.object(client, "post") as mock_post:
embedding(
model="cohere/embed-english-v3.0",
input=input,
extra_headers={"my-test-param": "hello-world"},
client=client,
)
try:
if sync_mode:
embedding(**data)
else:
await litellm.aembedding(**data)
except Exception as e:
print(e)
mock_post.assert_called_once()
assert "my-test-param" in mock_post.call_args.kwargs["headers"]