diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index 56f50c119..d42bd003f 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -945,6 +945,7 @@ def embedding( encoding=None, vertex_project=None, vertex_location=None, + aembedding=False, ): # logic for parsing in - calling - parsing out model embedding calls try: @@ -972,9 +973,95 @@ def embedding( try: llm_model = TextEmbeddingModel.from_pretrained(model) + except Exception as e: + raise VertexAIError(status_code=422, message=str(e)) + + if aembedding == True: + return async_embedding( + model=model, + client=llm_model, + input=input, + logging_obj=logging_obj, + model_response=model_response, + optional_params=optional_params, + encoding=encoding, + ) + + request_str = f"""embeddings = llm_model.get_embeddings({input})""" + ## LOGGING PRE-CALL + logging_obj.pre_call( + input=input, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + + try: embeddings = llm_model.get_embeddings(input) except Exception as e: raise VertexAIError(status_code=500, message=str(e)) + + ## LOGGING POST-CALL + logging_obj.post_call(input=input, api_key=None, original_response=embeddings) + ## Populate OpenAI compliant dictionary + embedding_response = [] + for idx, embedding in enumerate(embeddings): + embedding_response.append( + { + "object": "embedding", + "index": idx, + "embedding": embedding.values, + } + ) + model_response["object"] = "list" + model_response["data"] = embedding_response + model_response["model"] = model + input_tokens = 0 + + input_str = "".join(input) + + input_tokens += len(encoding.encode(input_str)) + + usage = Usage( + prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + ) + model_response.usage = usage + + return model_response + + +async def async_embedding( + model: str, + input: Union[list, str], + logging_obj=None, + model_response=None, + optional_params=None, + encoding=None, + client=None, +): + """ + Async embedding implementation + """ + request_str = f"""embeddings = llm_model.get_embeddings({input})""" + ## LOGGING PRE-CALL + logging_obj.pre_call( + input=input, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + "request_str": request_str, + }, + ) + + try: + embeddings = await client.get_embeddings_async(input) + except Exception as e: + raise VertexAIError(status_code=500, message=str(e)) + + ## LOGGING POST-CALL + logging_obj.post_call(input=input, api_key=None, original_response=embeddings) ## Populate OpenAI compliant dictionary embedding_response = [] for idx, embedding in enumerate(embeddings): diff --git a/litellm/main.py b/litellm/main.py index 9a24db87a..f37756b11 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2211,6 +2211,7 @@ async def aembedding(*args, **kwargs): or custom_llm_provider == "deepinfra" or custom_llm_provider == "perplexity" or custom_llm_provider == "ollama" + or custom_llm_provider == "vertex_ai" ): # currently implemented aiohttp calls for just azure and openai, soon all. # Await normally init_response = await loop.run_in_executor(None, func_with_context) @@ -2549,6 +2550,7 @@ def embedding( model_response=EmbeddingResponse(), vertex_project=vertex_ai_project, vertex_location=vertex_ai_location, + aembedding=aembedding, ) elif custom_llm_provider == "oobabooga": response = oobabooga.embedding( diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index 565e86fc2..4637a79e0 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -243,6 +243,19 @@ def test_vertexai_embedding(): pytest.fail(f"Error occurred: {e}") +@pytest.mark.asyncio +async def test_vertexai_aembedding(): + try: + # litellm.set_verbose=True + response = await litellm.aembedding( + model="textembedding-gecko@001", + input=["good morning from litellm", "this is another item"], + ) + print(f"response: {response}") + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + def test_bedrock_embedding_titan(): try: # this tests if we support str input for bedrock embedding