From aaa4a32d65134d82736a403acb4e8f4056e4cfcd Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Wed, 12 Jun 2024 10:11:44 -0700
Subject: [PATCH] feat - support vertex text input

---
 litellm/llms/vertex_ai.py | 32 +++++++++++++++++++-------------
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index c35aa502c9..3b35562ecf 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -1385,7 +1385,7 @@ def embedding(
             message="vertexai import failed please run `pip install google-cloud-aiplatform`",
         )
 
-    from vertexai.language_models import TextEmbeddingModel
+    from vertexai.language_models import TextEmbeddingModel, TextEmbeddingInput
     import google.auth  # type: ignore
 
     ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
@@ -1416,6 +1416,19 @@ def embedding(
     if isinstance(input, str):
         input = [input]
 
+    """
+    VertexAI supports passing embedding input like this:
+        input=[
+            {
+                "text": "good morning from litellm",
+                "task_type": "RETRIEVAL_DOCUMENT"
+            }
+        ],
+
+    In this scenario we cast each dict to a TextEmbeddingInput; plain strings pass through unchanged.
+    """
+    input = [TextEmbeddingInput(**x) if isinstance(x, dict) else x for x in input]
+
     try:
         llm_model = TextEmbeddingModel.from_pretrained(model)
     except Exception as e:
@@ -1453,6 +1466,7 @@ def embedding(
     logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
+    input_tokens: int = 0
     for idx, embedding in enumerate(embeddings):
         embedding_response.append(
             {
@@ -1461,14 +1475,10 @@ def embedding(
                 "embedding": embedding.values,
             }
         )
+        input_tokens += embedding.statistics.token_count
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
-    input_tokens = 0
-
-    input_str = "".join(input)
-
-    input_tokens += len(encoding.encode(input_str))
     usage = Usage(
         prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
     )
@@ -1511,6 +1521,7 @@ async def async_embedding(
     logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
+    input_tokens: int = 0
     for idx, embedding in enumerate(embeddings):
         embedding_response.append(
             {
@@ -1519,18 +1530,13 @@ async def async_embedding(
                 "embedding": embedding.values,
             }
         )
+        input_tokens += embedding.statistics.token_count
+
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
-    input_tokens = 0
-
-    input_str = "".join(input)
-
-    input_tokens += len(encoding.encode(input_str))
-
     usage = Usage(
         prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
     )
     model_response.usage = usage
-
     return model_response
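
Notes (illustrative, not part of the patch):

A minimal sketch of the call shape this change enables. The "text" /
"task_type" keys and the "RETRIEVAL_DOCUMENT" value come from the docstring
added above; the model id "vertex_ai/textembedding-gecko" and the
vertex_project / vertex_location parameters are assumptions for illustration,
not taken from this diff.

    import litellm

    # Dicts in `input` are cast to vertexai TextEmbeddingInput by the patched
    # code path; plain strings can be mixed in and pass through unchanged.
    response = litellm.embedding(
        model="vertex_ai/textembedding-gecko",  # assumed model id
        input=[
            {
                "text": "good morning from litellm",
                "task_type": "RETRIEVAL_DOCUMENT",
            },
            "plain strings still work",
        ],
        vertex_project="my-gcp-project",  # assumption
        vertex_location="us-central1",  # assumption
    )
    print(response.data[0]["embedding"][:5])
    # prompt_tokens is now the server-reported count, not a tiktoken estimate
    print(response.usage.prompt_tokens)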
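
A second sketch at the vertexai SDK layer, showing where the new token
accounting comes from. get_embeddings() is the standard vertexai call but does
not appear in the hunks above, so treat it as an assumption; TextEmbeddingInput,
from_pretrained(), .values, and .statistics.token_count are all named in the
diff.

    from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel

    # assumed model id for illustration
    llm_model = TextEmbeddingModel.from_pretrained("textembedding-gecko")
    embeddings = llm_model.get_embeddings(  # assumed call, not shown in the diff
        [
            TextEmbeddingInput(
                text="good morning from litellm",
                task_type="RETRIEVAL_DOCUMENT",
            )
        ]
    )
    for emb in embeddings:
        print(len(emb.values))  # embedding vector length
        print(emb.statistics.token_count)  # server-side count summed into Usage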