feat(vertex_ai.py): vertex ai gecko text embedding support

Krrish Dholakia 2024-02-03 09:48:29 -08:00
parent 6cdb9aede0
commit d9ba8668f4
6 changed files with 154 additions and 5 deletions

litellm/__init__.py

@@ -234,6 +234,7 @@ vertex_chat_models: List = []
 vertex_code_chat_models: List = []
 vertex_text_models: List = []
 vertex_code_text_models: List = []
+vertex_embedding_models: List = []
 ai21_models: List = []
 nlp_cloud_models: List = []
 aleph_alpha_models: List = []
@@ -263,6 +264,8 @@ for key, value in model_cost.items():
         vertex_chat_models.append(key)
     elif value.get("litellm_provider") == "vertex_ai-code-chat-models":
         vertex_code_chat_models.append(key)
+    elif value.get("litellm_provider") == "vertex_ai-embedding-models":
+        vertex_embedding_models.append(key)
     elif value.get("litellm_provider") == "ai21":
         ai21_models.append(key)
     elif value.get("litellm_provider") == "nlp_cloud":
@@ -499,7 +502,10 @@ bedrock_embedding_models: List = [
 ]

 all_embedding_models = (
-    open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
+    open_ai_embedding_models
+    + cohere_embedding_models
+    + bedrock_embedding_models
+    + vertex_embedding_models
 )

 ####### IMAGE GENERATION MODELS ###################
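Read together, these three hunks populate the new provider list from the cost map at import time and fold it into all_embedding_models. A quick interactive check (assuming the five gecko entries added to the cost map at the end of this commit have loaded):

import litellm

# populated by the model_cost loop from "vertex_ai-embedding-models" entries
print("textembedding-gecko@001" in litellm.vertex_embedding_models)  # True
print("textembedding-gecko@001" in litellm.all_embedding_models)     # True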

litellm/llms/vertex_ai.py

@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests
 import time
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
 import litellm, uuid
 import httpx
@@ -935,6 +935,68 @@ async def async_streaming(
     return streamwrapper


-def embedding():
+def embedding(
+    model: str,
+    input: Union[list, str],
+    api_key: Optional[str] = None,
+    logging_obj=None,
+    model_response=None,
+    optional_params=None,
+    encoding=None,
+    vertex_project=None,
+    vertex_location=None,
+):
     # logic for parsing in - calling - parsing out model embedding calls
-    pass
+    try:
+        import vertexai
+    except:
+        raise VertexAIError(
+            status_code=400,
+            message="vertexai import failed please run `pip install google-cloud-aiplatform`",
+        )
+
+    from vertexai.language_models import TextEmbeddingModel
+    import google.auth
+
+    ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
+    try:
+        creds, _ = google.auth.default(quota_project_id=vertex_project)
+        vertexai.init(
+            project=vertex_project, location=vertex_location, credentials=creds
+        )
+    except Exception as e:
+        raise VertexAIError(status_code=401, message=str(e))
+
+    if isinstance(input, str):
+        input = [input]
+
+    try:
+        llm_model = TextEmbeddingModel.from_pretrained(model)
+        embeddings = llm_model.get_embeddings(input)
+    except Exception as e:
+        raise VertexAIError(status_code=500, message=str(e))
+
+    ## Populate OpenAI compliant dictionary
+    embedding_response = []
+    for idx, embedding in enumerate(embeddings):
+        embedding_response.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding.values,
+            }
+        )
+    model_response["object"] = "list"
+    model_response["data"] = embedding_response
+    model_response["model"] = model
+    input_tokens = 0
+
+    input_str = "".join(input)
+
+    input_tokens += len(encoding.encode(input_str))
+    usage = Usage(
+        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+    )
+    model_response.usage = usage
+
+    return model_response
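The dictionary assembled here deliberately mirrors OpenAI's embedding response, so nothing downstream has to special-case Vertex. For a two-item input the populated model_response fields look roughly like this (values illustrative, vectors truncated):

{
    "object": "list",
    "model": "textembedding-gecko@001",
    "data": [
        {"object": "embedding", "index": 0, "embedding": [0.0123, -0.0456]},
        {"object": "embedding", "index": 1, "embedding": [0.0789, 0.0012]},
    ],
}

Usage counts every input token as a prompt token (completion_tokens stays 0), computed by joining the inputs and running them through the tokenizer passed in as encoding.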

litellm/main.py

@@ -2486,7 +2486,7 @@ def embedding(
             client=client,
             aembedding=aembedding,
         )
-    elif model in litellm.cohere_embedding_models:
+    elif custom_llm_provider == "cohere":
         cohere_key = (
             api_key
             or litellm.cohere_key
@@ -2528,6 +2528,28 @@ def embedding(
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
         )
+    elif custom_llm_provider == "vertex_ai":
+        vertex_ai_project = (
+            optional_params.pop("vertex_ai_project", None)
+            or litellm.vertex_project
+            or get_secret("VERTEXAI_PROJECT")
+        )
+        vertex_ai_location = (
+            optional_params.pop("vertex_ai_location", None)
+            or litellm.vertex_location
+            or get_secret("VERTEXAI_LOCATION")
+        )
+
+        response = vertex_ai.embedding(
+            model=model,
+            input=input,
+            encoding=encoding,
+            logging_obj=logging,
+            optional_params=optional_params,
+            model_response=EmbeddingResponse(),
+            vertex_project=vertex_ai_project,
+            vertex_location=vertex_ai_location,
+        )
     elif custom_llm_provider == "oobabooga":
         response = oobabooga.embedding(
             model=model,
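With the router branch in place, callers only need project and location to resolve from one of the three sources above: an explicit kwarg, the litellm module setting, or the VERTEXAI_PROJECT / VERTEXAI_LOCATION secrets. A minimal end-to-end sketch, assuming application-default GCP credentials and hypothetical project/region values:

import litellm

litellm.vertex_project = "my-gcp-project"  # hypothetical; or set VERTEXAI_PROJECT
litellm.vertex_location = "us-central1"    # hypothetical; or set VERTEXAI_LOCATION

response = litellm.embedding(
    model="textembedding-gecko@001",
    input=["good morning from litellm"],
)
print(response.model)                      # textembedding-gecko@001
print(len(response.data[0]["embedding"]))  # 768, per the cost-map entries below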

litellm/tests/test_embedding.py

@@ -231,6 +231,19 @@ def test_cohere_embedding3():

 # test_cohere_embedding3()

+
+def test_vertexai_embedding():
+    try:
+        # litellm.set_verbose=True
+        response = embedding(
+            model="textembedding-gecko@001",
+            input=["good morning from litellm", "this is another item"],
+        )
+        print(f"response:", response)
+        raise Exception("it worked!")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
 def test_bedrock_embedding_titan():
     try:
         # this tests if we support str input for bedrock embedding
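One quirk worth noting: the raise Exception("it worked!") inside the try block funnels even a successful call into pytest.fail, so as committed the test always fails; it reads as a development-time probe. A passing variant is the same body without that line:

def test_vertexai_embedding():
    try:
        response = embedding(
            model="textembedding-gecko@001",
            input=["good morning from litellm", "this is another item"],
        )
        print(f"response: {response}")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")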

litellm/utils.py

@@ -4538,6 +4538,7 @@ def get_llm_provider(
             or model in litellm.vertex_text_models
             or model in litellm.vertex_code_text_models
             or model in litellm.vertex_language_models
+            or model in litellm.vertex_embedding_models
         ):
             custom_llm_provider = "vertex_ai"
         ## ai21
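This one-line change is what lets the router above see custom_llm_provider == "vertex_ai" for a bare gecko model name. A quick check (assuming get_llm_provider's usual four-tuple return of model, provider, api key, api base):

from litellm.utils import get_llm_provider

model, provider, _, _ = get_llm_provider(model="textembedding-gecko@001")
print(provider)  # vertex_ai -- no "vertex_ai/" prefix required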

model_prices_and_context_window.json

@@ -612,6 +612,51 @@
         "litellm_provider": "vertex_ai-vision-models",
         "mode": "chat"
     },
+    "textembedding-gecko": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko-multilingual": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko-multilingual@001": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko@001": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko@003": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
     "palm/chat-bison": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000000125,