forked from phoenix/litellm-mirror
feat(cohere/embed): pass client to async embedding
This commit is contained in:
parent
1a3fb18a64
commit
94fe135524
2 changed files with 28 additions and 9 deletions
|
@ -74,6 +74,7 @@ async def async_embedding(
|
|||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
|
||||
if client is None:
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.COHERE,
|
||||
|
@ -151,6 +152,11 @@ def embedding(
|
|||
api_key=api_key,
|
||||
headers=headers,
|
||||
encoding=encoding,
|
||||
client=(
|
||||
client
|
||||
if client is not None and isinstance(client, AsyncHTTPHandler)
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
|
|
|
@ -1082,18 +1082,31 @@ def test_cohere_img_embeddings(input, input_type):
|
|||
assert response.usage.prompt_tokens_details.text_tokens > 0
|
||||
|
||||
|
||||
def test_embedding_with_extra_headers():
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedding_with_extra_headers(sync_mode):
|
||||
input = ["hello world"]
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
||||
|
||||
client = HTTPHandler()
|
||||
if sync_mode:
|
||||
client = HTTPHandler()
|
||||
else:
|
||||
client = AsyncHTTPHandler()
|
||||
|
||||
data = {
|
||||
"model": "cohere/embed-english-v3.0",
|
||||
"input": input,
|
||||
"extra_headers": {"my-test-param": "hello-world"},
|
||||
"client": client,
|
||||
}
|
||||
with patch.object(client, "post") as mock_post:
|
||||
embedding(
|
||||
model="cohere/embed-english-v3.0",
|
||||
input=input,
|
||||
extra_headers={"my-test-param": "hello-world"},
|
||||
client=client,
|
||||
)
|
||||
try:
|
||||
if sync_mode:
|
||||
embedding(**data)
|
||||
else:
|
||||
await litellm.aembedding(**data)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
mock_post.assert_called_once()
|
||||
assert "my-test-param" in mock_post.call_args.kwargs["headers"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue