mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(vertex_ai.py): vertex ai gecko text embedding support
This commit is contained in:
parent
6cdb9aede0
commit
d9ba8668f4
6 changed files with 154 additions and 5 deletions
|
@ -3,7 +3,7 @@ import json
|
|||
from enum import Enum
|
||||
import requests
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, Union
|
||||
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
|
||||
import litellm, uuid
|
||||
import httpx
|
||||
|
@ -935,6 +935,68 @@ async def async_streaming(
|
|||
return streamwrapper
|
||||
|
||||
|
||||
def embedding():
|
||||
def embedding(
|
||||
model: str,
|
||||
input: Union[list, str],
|
||||
api_key: Optional[str] = None,
|
||||
logging_obj=None,
|
||||
model_response=None,
|
||||
optional_params=None,
|
||||
encoding=None,
|
||||
vertex_project=None,
|
||||
vertex_location=None,
|
||||
):
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
try:
|
||||
import vertexai
|
||||
except:
|
||||
raise VertexAIError(
|
||||
status_code=400,
|
||||
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
|
||||
)
|
||||
|
||||
from vertexai.language_models import TextEmbeddingModel
|
||||
import google.auth
|
||||
|
||||
## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
|
||||
try:
|
||||
creds, _ = google.auth.default(quota_project_id=vertex_project)
|
||||
vertexai.init(
|
||||
project=vertex_project, location=vertex_location, credentials=creds
|
||||
)
|
||||
except Exception as e:
|
||||
raise VertexAIError(status_code=401, message=str(e))
|
||||
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
try:
|
||||
llm_model = TextEmbeddingModel.from_pretrained(model)
|
||||
embeddings = llm_model.get_embeddings(input)
|
||||
except Exception as e:
|
||||
raise VertexAIError(status_code=500, message=str(e))
|
||||
## Populate OpenAI compliant dictionary
|
||||
embedding_response = []
|
||||
for idx, embedding in enumerate(embeddings):
|
||||
embedding_response.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding.values,
|
||||
}
|
||||
)
|
||||
model_response["object"] = "list"
|
||||
model_response["data"] = embedding_response
|
||||
model_response["model"] = model
|
||||
input_tokens = 0
|
||||
|
||||
input_str = "".join(input)
|
||||
|
||||
input_tokens += len(encoding.encode(input_str))
|
||||
|
||||
usage = Usage(
|
||||
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
|
||||
)
|
||||
model_response.usage = usage
|
||||
|
||||
return model_response
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue