diff --git a/litellm/main.py b/litellm/main.py index f5618c055..86086041f 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1447,6 +1447,21 @@ def batch_completion_models_all_responses(*args, **kwargs): return responses ### EMBEDDING ENDPOINTS #################### + +async def aembedding(*args, **kwargs): + loop = asyncio.get_event_loop() + + # Use a partial function to pass your keyword arguments + func = partial(embedding, *args, **kwargs) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + + # Call the synchronous function using run_in_executor + response = await loop.run_in_executor(None, func_with_context) + return response + @client @timeout( # type: ignore 60 diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index 3d79100ee..03d613055 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -91,4 +91,20 @@ def test_cohere_embedding(): # pytest.fail(f"Error occurred: {e}") # test_hf_embedding() +# test async embeddings +def test_aembedding(): + import asyncio + async def embedding_call(): + try: + response = await litellm.aembedding( + model="text-embedding-ada-002", + input=["good morning from litellm", "this is another item"] + ) + print(response) + except: + print(f"error occurred: {traceback.format_exc()}") + pass + asyncio.run(embedding_call()) + +# test_aembedding()