forked from phoenix/litellm-mirror
feat(cohere/embed): pass client to async embedding
This commit is contained in:
parent
1a3fb18a64
commit
94fe135524
2 changed files with 28 additions and 9 deletions
|
@ -74,6 +74,7 @@ async def async_embedding(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
## COMPLETION CALL
|
## COMPLETION CALL
|
||||||
|
|
||||||
if client is None:
|
if client is None:
|
||||||
client = get_async_httpx_client(
|
client = get_async_httpx_client(
|
||||||
llm_provider=litellm.LlmProviders.COHERE,
|
llm_provider=litellm.LlmProviders.COHERE,
|
||||||
|
@ -151,6 +152,11 @@ def embedding(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
|
client=(
|
||||||
|
client
|
||||||
|
if client is not None and isinstance(client, AsyncHTTPHandler)
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
|
|
@ -1082,18 +1082,31 @@ def test_cohere_img_embeddings(input, input_type):
|
||||||
assert response.usage.prompt_tokens_details.text_tokens > 0
|
assert response.usage.prompt_tokens_details.text_tokens > 0
|
||||||
|
|
||||||
|
|
||||||
def test_embedding_with_extra_headers():
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embedding_with_extra_headers(sync_mode):
|
||||||
input = ["hello world"]
|
input = ["hello world"]
|
||||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
||||||
|
|
||||||
client = HTTPHandler()
|
if sync_mode:
|
||||||
|
client = HTTPHandler()
|
||||||
|
else:
|
||||||
|
client = AsyncHTTPHandler()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": "cohere/embed-english-v3.0",
|
||||||
|
"input": input,
|
||||||
|
"extra_headers": {"my-test-param": "hello-world"},
|
||||||
|
"client": client,
|
||||||
|
}
|
||||||
with patch.object(client, "post") as mock_post:
|
with patch.object(client, "post") as mock_post:
|
||||||
embedding(
|
try:
|
||||||
model="cohere/embed-english-v3.0",
|
if sync_mode:
|
||||||
input=input,
|
embedding(**data)
|
||||||
extra_headers={"my-test-param": "hello-world"},
|
else:
|
||||||
client=client,
|
await litellm.aembedding(**data)
|
||||||
)
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
mock_post.assert_called_once()
|
||||||
assert "my-test-param" in mock_post.call_args.kwargs["headers"]
|
assert "my-test-param" in mock_post.call_args.kwargs["headers"]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue