# Patch summary (two files, three hunks).
#
# litellm/llms/cohere/embed/handler.py
#   * async_embedding: the hunk grows from 6 to 7 lines and the only added
#     line sits right after the "## COMPLETION CALL" marker, so it appears to
#     be a blank line; no logic change is visible in this hunk.
#   * embedding: adds a `client=` keyword to a call that already receives
#     api_key, headers and encoding. The caller's client is forwarded only
#     when it is a non-None AsyncHTTPHandler; any other value becomes None.
#     NOTE(review): the callee's name is outside the hunk -- presumably
#     async_embedding, which (per the first hunk) builds its own client via
#     get_async_httpx_client when `client is None`. Confirm against the file.
#
# tests/local_testing/test_embedding.py
#   * test_embedding_with_extra_headers becomes an async test parametrized
#     over sync_mode (True/False): HTTPHandler + embedding(**data) for the
#     sync case, AsyncHTTPHandler + `await litellm.aembedding(**data)` for
#     the async case.
#   * The call runs under patch.object(client, "post") and inside a
#     try/except Exception that only prints the error, so a failure while
#     handling the mocked response does not fail the test by itself.
#   * The test then asserts that the patched `post` was called exactly once
#     and that "my-test-param" is present in the `headers` keyword argument.
#
# NOTE(review): the patch's newlines appear to have been collapsed into a
# single line below (hunk headers announce multi-line hunks); restore the
# original line breaks before trying to apply it.
diff --git a/litellm/llms/cohere/embed/handler.py b/litellm/llms/cohere/embed/handler.py index 5b224c375..afeba10b5 100644 --- a/litellm/llms/cohere/embed/handler.py +++ b/litellm/llms/cohere/embed/handler.py @@ -74,6 +74,7 @@ async def async_embedding( }, ) ## COMPLETION CALL + if client is None: client = get_async_httpx_client( llm_provider=litellm.LlmProviders.COHERE, @@ -151,6 +152,11 @@ def embedding( api_key=api_key, headers=headers, encoding=encoding, + client=( + client + if client is not None and isinstance(client, AsyncHTTPHandler) + else None + ), ) ## LOGGING diff --git a/tests/local_testing/test_embedding.py b/tests/local_testing/test_embedding.py index 97707a234..23d712b00 100644 --- a/tests/local_testing/test_embedding.py +++ b/tests/local_testing/test_embedding.py @@ -1082,18 +1082,31 @@ def test_cohere_img_embeddings(input, input_type): assert response.usage.prompt_tokens_details.text_tokens > 0 -def test_embedding_with_extra_headers(): +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_embedding_with_extra_headers(sync_mode): input = ["hello world"] - from litellm.llms.custom_httpx.http_handler import HTTPHandler + from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler - client = HTTPHandler() + if sync_mode: + client = HTTPHandler() + else: + client = AsyncHTTPHandler() + data = { + "model": "cohere/embed-english-v3.0", + "input": input, + "extra_headers": {"my-test-param": "hello-world"}, + "client": client, + } with patch.object(client, "post") as mock_post: - embedding( - model="cohere/embed-english-v3.0", - input=input, - extra_headers={"my-test-param": "hello-world"}, - client=client, - ) + try: + if sync_mode: + embedding(**data) + else: + await litellm.aembedding(**data) + except Exception as e: + print(e) + mock_post.assert_called_once() assert "my-test-param" in mock_post.call_args.kwargs["headers"]