feat - support vertex text input

Ishaan Jaff 2024-06-12 10:11:44 -07:00
parent b9f50d83c4
commit aaa4a32d65


@@ -1385,7 +1385,7 @@ def embedding(
             message="vertexai import failed please run `pip install google-cloud-aiplatform`",
         )
-    from vertexai.language_models import TextEmbeddingModel
+    from vertexai.language_models import TextEmbeddingModel, TextEmbeddingInput
     import google.auth  # type: ignore

     ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
@@ -1416,6 +1416,19 @@ def embedding(
     if isinstance(input, str):
         input = [input]

+    """
+    VertexAI supports passing embedding input like this:
+        input=[
+            {
+                "text": "good morning from litellm",
+                "task_type": "RETRIEVAL_DOCUMENT"
+            }
+        ],
+    In this scenario we cast it to TextEmbeddingInput
+    """
+    input = [
+        TextEmbeddingInput(**x) if isinstance(x, dict) else x for x in input
+    ]
     try:
         llm_model = TextEmbeddingModel.from_pretrained(model)
     except Exception as e:
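For context, a minimal usage sketch of the new dict-style input; the model string `vertex_ai/textembedding-gecko` is an illustrative assumption, not part of this commit:

    import litellm

    # Hypothetical call exercising the dict input shape handled above;
    # each dict is cast to a vertexai TextEmbeddingInput before the request.
    response = litellm.embedding(
        model="vertex_ai/textembedding-gecko",  # assumed model name for illustration
        input=[
            {
                "text": "good morning from litellm",
                "task_type": "RETRIEVAL_DOCUMENT",
            }
        ],
    )
    print(response.usage)

Plain string inputs keep working, since non-dict items pass through the cast unchanged.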
@@ -1453,6 +1466,7 @@ def embedding(
     logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
+    input_tokens: int = 0
     for idx, embedding in enumerate(embeddings):
         embedding_response.append(
             {
@@ -1461,14 +1475,10 @@ def embedding(
                 "embedding": embedding.values,
             }
         )
+        input_tokens += embedding.statistics.token_count
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
-    input_tokens = 0
-    input_str = "".join(input)
-    input_tokens += len(encoding.encode(input_str))
     usage = Usage(
         prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
     )
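This hunk replaces re-encoding the concatenated input (which would break once `input` holds TextEmbeddingInput objects rather than strings) with the per-embedding token counts Vertex AI reports. A minimal sketch of the pattern, assuming the same `TextEmbeddingModel` setup as the surrounding code:

    from vertexai.language_models import TextEmbeddingModel

    llm_model = TextEmbeddingModel.from_pretrained("textembedding-gecko")
    embeddings = llm_model.get_embeddings(["good morning from litellm"])

    input_tokens: int = 0
    for embedding in embeddings:
        # Each TextEmbedding carries statistics from the API, including
        # the number of input tokens the request consumed.
        input_tokens += embedding.statistics.token_count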
@@ -1511,6 +1521,7 @@ async def async_embedding(
     logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
+    input_tokens: int = 0
     for idx, embedding in enumerate(embeddings):
         embedding_response.append(
             {
@@ -1519,18 +1530,13 @@ async def async_embedding(
                 "embedding": embedding.values,
             }
         )
+        input_tokens += embedding.statistics.token_count
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
-    input_tokens = 0
-    input_str = "".join(input)
-    input_tokens += len(encoding.encode(input_str))
     usage = Usage(
         prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
     )
     model_response.usage = usage
     return model_response
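The async path gets the identical accounting fix. A hypothetical end-to-end call through litellm's async entry point, with the model string again assumed for illustration:

    import asyncio
    import litellm

    async def main():
        # aembedding is litellm's async counterpart to embedding()
        response = await litellm.aembedding(
            model="vertex_ai/textembedding-gecko",  # assumed model name
            input=[
                {"text": "good morning from litellm", "task_type": "RETRIEVAL_QUERY"}
            ],
        )
        print(response.usage.prompt_tokens)

    asyncio.run(main())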