diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index d2c5b6ff4..3f15e2f69 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -323,7 +323,26 @@ class OpenAIChatCompletion(BaseLLM):
                 raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
             else:
                 raise OpenAIError(status_code=500, message=f"{str(e)}")
-
+    async def aembedding(
+        self,
+        data: dict,
+        model_response: ModelResponse,
+        timeout: float,
+        api_key: Optional[str]=None,
+        api_base: Optional[str]=None,
+        client=None,
+        max_retries=None,
+    ):
+        response = None
+        try:
+            if client is None:
+                openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
+            else:
+                openai_aclient = client
+            response = await openai_aclient.embeddings.create(**data)
+            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding")
+        except Exception as e:
+            raise e
     def embedding(self,
                 model: str,
                 input: list,
@@ -334,6 +353,7 @@ class OpenAIChatCompletion(BaseLLM):
                 logging_obj=None,
                 optional_params=None,
                 client=None,
+                aembedding=None,
                 ):
         super().embedding()
         exception_mapping_worked = False
@@ -347,6 +367,9 @@ class OpenAIChatCompletion(BaseLLM):
             max_retries = data.pop("max_retries", 2)
             if not isinstance(max_retries, int):
                 raise OpenAIError(status_code=422, message="max retries must be an int")
+            if aembedding == True:
+                response = self.aembedding(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
+                return response
             if client is None:
                 openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
             else:
diff --git a/litellm/main.py b/litellm/main.py
index 03807b9d7..61552eead 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1668,17 +1668,43 @@ async def aembedding(*args, **kwargs):
         - `response` (Any): The response returned by the `embedding` function.
     """
     loop = asyncio.get_event_loop()
+    model = args[0] if len(args) > 0 else kwargs["model"]
+    ### PASS ARGS TO Embedding ###
+    kwargs["aembedding"] = True
+    custom_llm_provider = None
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(embedding, *args, **kwargs)
-    # Use a partial function to pass your keyword arguments
-    func = partial(embedding, *args, **kwargs)
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
-    # Add the context to the function
-    ctx = contextvars.copy_context()
-    func_with_context = partial(ctx.run, func)
+        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
-    # Call the synchronous function using run_in_executor
-    response = await loop.run_in_executor(None, func_with_context)
-    return response
+        if (custom_llm_provider == "openai"
+            or custom_llm_provider == "azure"
+            or custom_llm_provider == "custom_openai"
+            or custom_llm_provider == "anyscale"
+            or custom_llm_provider == "openrouter"
+            or custom_llm_provider == "deepinfra"
+            or custom_llm_provider == "perplexity"
+            or custom_llm_provider == "huggingface"): # currently implemented aiohttp calls for just azure and openai, soon all.
+            # Await normally
+            init_response = await loop.run_in_executor(None, func_with_context)
+            if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
+                response = init_response
+            elif asyncio.iscoroutine(init_response):
+                response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = await loop.run_in_executor(None, func_with_context)
+        return response
+    except Exception as e:
+        custom_llm_provider = custom_llm_provider or "openai"
+        raise exception_type(
+            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+        )

 @client
 def embedding(
@@ -1725,6 +1751,7 @@ def embedding(
     client = kwargs.pop("client", None)
     rpm = kwargs.pop("rpm", None)
     tpm = kwargs.pop("tpm", None)
+    aembedding = kwargs.pop("aembedding", None)
     optional_params = {}
     for param in kwargs:
@@ -1809,7 +1836,8 @@ def embedding(
                 timeout=timeout,
                 model_response=EmbeddingResponse(),
                 optional_params=optional_params,
-                client=client
+                client=client,
+                aembedding=aembedding,
             )
     elif model in litellm.cohere_embedding_models:
         cohere_key = (
diff --git a/litellm/utils.py b/litellm/utils.py
index b80732d49..924e14f5c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1303,6 +1303,8 @@ def client(original_function):
                     return result
                 elif "acompletion" in kwargs and kwargs["acompletion"] == True:
                     return result
+                elif "aembedding" in kwargs and kwargs["aembedding"] == True:
+                    return result
             ### POST-CALL RULES ###
             post_call_processing(original_response=result, model=model)
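Usage note (not part of the patch): with these changes, `litellm.aembedding` flags the call with `aembedding=True`, runs the synchronous `embedding()` in an executor, and, for OpenAI-compatible providers, awaits the coroutine returned by `OpenAIChatCompletion.aembedding`. A minimal sketch of how the new path would be exercised; the model name and environment setup are illustrative assumptions, not part of this diff:

    # Illustrative only: exercises the async embedding path added above.
    # Assumes OPENAI_API_KEY is set in the environment; the model name is an example.
    import asyncio
    import litellm

    async def main():
        # aembedding() sets kwargs["aembedding"] = True, so the OpenAI branch of
        # embedding() returns a coroutine that aembedding() then awaits.
        response = await litellm.aembedding(
            model="text-embedding-ada-002",
            input=["hello from the async embedding path"],
        )
        print(response)

    asyncio.run(main())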