diff --git a/litellm/main.py b/litellm/main.py index f69454aaad..8ec13f62a9 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3244,7 +3244,9 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse: func_with_context = partial(ctx.run, func) _, custom_llm_provider, _, _ = get_llm_provider( - model=model, api_base=kwargs.get("api_base", None) + model=model, + custom_llm_provider=kwargs.get("custom_llm_provider", None), + api_base=kwargs.get("api_base", None), ) # Await normally