mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
feat - support vertex text input
This commit is contained in:
parent
b9f50d83c4
commit
aaa4a32d65
1 changed files with 19 additions and 13 deletions
|
@ -1385,7 +1385,7 @@ def embedding(
|
|||
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
|
||||
)
|
||||
|
||||
from vertexai.language_models import TextEmbeddingModel
|
||||
from vertexai.language_models import TextEmbeddingModel, TextEmbeddingInput
|
||||
import google.auth # type: ignore
|
||||
|
||||
## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
|
||||
|
@ -1416,6 +1416,19 @@ def embedding(
|
|||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
"""
|
||||
VertexAI supports passing embedding input like this:
|
||||
input=[
|
||||
{
|
||||
"text": "good morning from litellm",
|
||||
"task_type": "RETRIEVAL_DOCUMENT"
|
||||
}
|
||||
],
|
||||
|
||||
In this scenario we cast it to TextEmbeddingInput
|
||||
"""
|
||||
input = [TextEmbeddingInput(**x) for x in input if isinstance(x, dict)]
|
||||
|
||||
try:
|
||||
llm_model = TextEmbeddingModel.from_pretrained(model)
|
||||
except Exception as e:
|
||||
|
@ -1453,6 +1466,7 @@ def embedding(
|
|||
logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
|
||||
## Populate OpenAI compliant dictionary
|
||||
embedding_response = []
|
||||
input_tokens: int = 0
|
||||
for idx, embedding in enumerate(embeddings):
|
||||
embedding_response.append(
|
||||
{
|
||||
|
@ -1461,14 +1475,10 @@ def embedding(
|
|||
"embedding": embedding.values,
|
||||
}
|
||||
)
|
||||
input_tokens += embedding.statistics.token_count
|
||||
model_response["object"] = "list"
|
||||
model_response["data"] = embedding_response
|
||||
model_response["model"] = model
|
||||
input_tokens = 0
|
||||
|
||||
input_str = "".join(input)
|
||||
|
||||
input_tokens += len(encoding.encode(input_str))
|
||||
|
||||
usage = Usage(
|
||||
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
|
||||
|
@ -1511,6 +1521,7 @@ async def async_embedding(
|
|||
logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
|
||||
## Populate OpenAI compliant dictionary
|
||||
embedding_response = []
|
||||
input_tokens: int = 0
|
||||
for idx, embedding in enumerate(embeddings):
|
||||
embedding_response.append(
|
||||
{
|
||||
|
@ -1519,18 +1530,13 @@ async def async_embedding(
|
|||
"embedding": embedding.values,
|
||||
}
|
||||
)
|
||||
input_tokens += embedding.statistics.token_count
|
||||
|
||||
model_response["object"] = "list"
|
||||
model_response["data"] = embedding_response
|
||||
model_response["model"] = model
|
||||
input_tokens = 0
|
||||
|
||||
input_str = "".join(input)
|
||||
|
||||
input_tokens += len(encoding.encode(input_str))
|
||||
|
||||
usage = Usage(
|
||||
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
|
||||
)
|
||||
model_response.usage = usage
|
||||
|
||||
return model_response
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue