feat(vertex_ai.py): vertex ai gecko text embedding support

This commit is contained in:
Krrish Dholakia 2024-02-03 09:48:29 -08:00
parent 6cdb9aede0
commit d9ba8668f4
6 changed files with 154 additions and 5 deletions

View file

@ -3,7 +3,7 @@ import json
from enum import Enum
import requests
import time
from typing import Callable, Optional
from typing import Callable, Optional, Union
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm, uuid
import httpx
@ -935,6 +935,68 @@ async def async_streaming(
return streamwrapper
def embedding():
def embedding(
model: str,
input: Union[list, str],
api_key: Optional[str] = None,
logging_obj=None,
model_response=None,
optional_params=None,
encoding=None,
vertex_project=None,
vertex_location=None,
):
# logic for parsing in - calling - parsing out model embedding calls
pass
try:
import vertexai
except:
raise VertexAIError(
status_code=400,
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
)
from vertexai.language_models import TextEmbeddingModel
import google.auth
## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
try:
creds, _ = google.auth.default(quota_project_id=vertex_project)
vertexai.init(
project=vertex_project, location=vertex_location, credentials=creds
)
except Exception as e:
raise VertexAIError(status_code=401, message=str(e))
if isinstance(input, str):
input = [input]
try:
llm_model = TextEmbeddingModel.from_pretrained(model)
embeddings = llm_model.get_embeddings(input)
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
## Populate OpenAI compliant dictionary
embedding_response = []
for idx, embedding in enumerate(embeddings):
embedding_response.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding.values,
}
)
model_response["object"] = "list"
model_response["data"] = embedding_response
model_response["model"] = model
input_tokens = 0
input_str = "".join(input)
input_tokens += len(encoding.encode(input_str))
usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
)
model_response.usage = usage
return model_response