mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)
feat - support vertex text input
commit aaa4a32d65 (parent b9f50d83c4)
1 changed file with 19 additions and 13 deletions
@@ -1385,7 +1385,7 @@ def embedding(
             message="vertexai import failed please run `pip install google-cloud-aiplatform`",
         )
 
-    from vertexai.language_models import TextEmbeddingModel
+    from vertexai.language_models import TextEmbeddingModel, TextEmbeddingInput
     import google.auth  # type: ignore
 
     ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
@@ -1416,6 +1416,19 @@ def embedding(
     if isinstance(input, str):
         input = [input]
 
+    """
+    VertexAI supports passing embedding input like this:
+        input=[
+            {
+                "text": "good morning from litellm",
+                "task_type": "RETRIEVAL_DOCUMENT"
+            }
+        ],
+
+    In this scenario we cast it to TextEmbeddingInput
+    """
+    input = [TextEmbeddingInput(**x) for x in input if isinstance(x, dict)]
+
     try:
         llm_model = TextEmbeddingModel.from_pretrained(model)
     except Exception as e:
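A minimal sketch of the cast this hunk introduces, assuming google-cloud-aiplatform is installed; TextEmbeddingInput is the Vertex SDK class whose fields include text and task_type:

    from vertexai.language_models import TextEmbeddingInput

    input = [
        {"text": "good morning from litellm", "task_type": "RETRIEVAL_DOCUMENT"},
    ]
    # Each dict is unpacked into TextEmbeddingInput(text=..., task_type=...);
    # entries that are not dicts fail the isinstance check and are filtered
    # out of the rebuilt list.
    input = [TextEmbeddingInput(**x) for x in input if isinstance(x, dict)]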
@@ -1453,6 +1466,7 @@ def embedding(
     logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
+    input_tokens: int = 0
     for idx, embedding in enumerate(embeddings):
         embedding_response.append(
             {
@@ -1461,14 +1475,10 @@ def embedding(
                 "embedding": embedding.values,
             }
         )
+        input_tokens += embedding.statistics.token_count
+
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
-    input_tokens = 0
-
-    input_str = "".join(input)
-
-    input_tokens += len(encoding.encode(input_str))
-
     usage = Usage(
         prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
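These two hunks replace the old local token estimate (tiktoken over "".join(input), which would raise a TypeError once input items are TextEmbeddingInput objects rather than strings) with the per-item counts the Vertex SDK reports. A hedged sketch of the new accounting, assuming each element of embeddings is a vertexai TextEmbedding whose statistics field is populated:

    input_tokens = 0
    for embedding in embeddings:
        # TextEmbedding.statistics.token_count is the token count Vertex
        # reports for that input item; summing these replaces the tiktoken
        # estimate over the concatenated inputs.
        input_tokens += embedding.statistics.token_count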
@@ -1511,6 +1521,7 @@ async def async_embedding(
     logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
+    input_tokens: int = 0
     for idx, embedding in enumerate(embeddings):
         embedding_response.append(
             {
@@ -1519,18 +1530,13 @@ async def async_embedding(
                 "embedding": embedding.values,
             }
         )
+        input_tokens += embedding.statistics.token_count
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
-    input_tokens = 0
-
-    input_str = "".join(input)
-
-    input_tokens += len(encoding.encode(input_str))
-
     usage = Usage(
         prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
     )
     model_response.usage = usage
 
     return model_response
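With this change, a caller can pass Vertex task types through litellm's embedding entry point. A hedged usage sketch (the model string is an example, not taken from this commit):

    import litellm

    # Each input item may now be a dict carrying Vertex-specific metadata
    # alongside the text; litellm casts these to TextEmbeddingInput.
    response = litellm.embedding(
        model="vertex_ai/textembedding-gecko",
        input=[
            {"text": "good morning from litellm", "task_type": "RETRIEVAL_DOCUMENT"}
        ],
    )
    # prompt_tokens is now summed from Vertex's per-item token statistics.
    print(response.usage.prompt_tokens)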