diff --git a/tests/local_testing/test_embedding.py b/tests/local_testing/test_embedding.py index d7988e690..97707a234 100644 --- a/tests/local_testing/test_embedding.py +++ b/tests/local_testing/test_embedding.py @@ -1080,3 +1080,20 @@ def test_cohere_img_embeddings(input, input_type): assert response.usage.prompt_tokens_details.image_tokens > 0 else: assert response.usage.prompt_tokens_details.text_tokens > 0 + + +def test_embedding_with_extra_headers(): + input = ["hello world"] + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + client = HTTPHandler() + + with patch.object(client, "post") as mock_post: + embedding( + model="cohere/embed-english-v3.0", + input=input, + extra_headers={"my-test-param": "hello-world"}, + client=client, + ) + + assert "my-test-param" in mock_post.call_args.kwargs["headers"]