feat(vertex_ai.py): vertex ai gecko text embedding support

Krrish Dholakia 2024-02-03 09:48:29 -08:00
parent 6cdb9aede0
commit d9ba8668f4
6 changed files with 154 additions and 5 deletions
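In practice the new capability is reached through litellm's top-level embedding entrypoint, mirroring the test added in this commit. A minimal usage sketch, assuming Google Application Default Credentials are configured and that VERTEXAI_PROJECT / VERTEXAI_LOCATION (or litellm.vertex_project / litellm.vertex_location) identify the GCP project:

import litellm

# Mirrors test_vertexai_embedding below; requires `pip install google-cloud-aiplatform`
# and working GCP credentials. Returns an OpenAI-style EmbeddingResponse.
response = litellm.embedding(
    model="textembedding-gecko@001",
    input=["good morning from litellm", "this is another item"],
)
print(len(response["data"]))  # one embedding dict per input string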

View file

@@ -234,6 +234,7 @@ vertex_chat_models: List = []
 vertex_code_chat_models: List = []
 vertex_text_models: List = []
 vertex_code_text_models: List = []
+vertex_embedding_models: List = []
 ai21_models: List = []
 nlp_cloud_models: List = []
 aleph_alpha_models: List = []
@@ -263,6 +264,8 @@ for key, value in model_cost.items():
         vertex_chat_models.append(key)
     elif value.get("litellm_provider") == "vertex_ai-code-chat-models":
         vertex_code_chat_models.append(key)
+    elif value.get("litellm_provider") == "vertex_ai-embedding-models":
+        vertex_embedding_models.append(key)
     elif value.get("litellm_provider") == "ai21":
         ai21_models.append(key)
     elif value.get("litellm_provider") == "nlp_cloud":
@@ -499,7 +502,10 @@ bedrock_embedding_models: List = [
 ]

 all_embedding_models = (
-    open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
+    open_ai_embedding_models
+    + cohere_embedding_models
+    + bedrock_embedding_models
+    + vertex_embedding_models
 )

 ####### IMAGE GENERATION MODELS ###################

View file

@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests
 import time
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
 import litellm, uuid
 import httpx
@@ -935,6 +935,68 @@ async def async_streaming(
     return streamwrapper


-def embedding():
+def embedding(
+    model: str,
+    input: Union[list, str],
+    api_key: Optional[str] = None,
+    logging_obj=None,
+    model_response=None,
+    optional_params=None,
+    encoding=None,
+    vertex_project=None,
+    vertex_location=None,
+):
     # logic for parsing in - calling - parsing out model embedding calls
-    pass
+    try:
+        import vertexai
+    except ImportError:
+        raise VertexAIError(
+            status_code=400,
+            message="vertexai import failed, please run `pip install google-cloud-aiplatform`",
+        )
+
+    from vertexai.language_models import TextEmbeddingModel
+    import google.auth
+
+    ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
+    try:
+        creds, _ = google.auth.default(quota_project_id=vertex_project)
+        vertexai.init(
+            project=vertex_project, location=vertex_location, credentials=creds
+        )
+    except Exception as e:
+        raise VertexAIError(status_code=401, message=str(e))
+
+    if isinstance(input, str):
+        input = [input]
+
+    try:
+        llm_model = TextEmbeddingModel.from_pretrained(model)
+        embeddings = llm_model.get_embeddings(input)
+    except Exception as e:
+        raise VertexAIError(status_code=500, message=str(e))
+
+    ## Populate OpenAI compliant dictionary
+    embedding_response = []
+    for idx, embedding in enumerate(embeddings):
+        embedding_response.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding.values,
+            }
+        )
+    model_response["object"] = "list"
+    model_response["data"] = embedding_response
+    model_response["model"] = model
+
+    input_tokens = 0
+    input_str = "".join(input)
+    input_tokens += len(encoding.encode(input_str))
+    usage = Usage(
+        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+    )
+    model_response.usage = usage
+    return model_response
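For reference, the OpenAI-compatible payload assembled above ends up shaped like this (a hand-written illustration; the vector values are made up, and real gecko vectors have 768 dimensions per the pricing entries later in this commit):

# Illustrative shape of model_response after embedding() returns:
example_response = {
    "object": "list",
    "data": [
        {"object": "embedding", "index": 0, "embedding": [0.0123, -0.0456]},  # truncated vector
        {"object": "embedding", "index": 1, "embedding": [0.0789, 0.0101]},  # truncated vector
    ],
    "model": "textembedding-gecko@001",
    "usage": {"prompt_tokens": 10, "completion_tokens": 0, "total_tokens": 10},  # a Usage object, shown here as a dict
}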

View file

@@ -2486,7 +2486,7 @@ def embedding(
             client=client,
             aembedding=aembedding,
         )
-    elif model in litellm.cohere_embedding_models:
+    elif custom_llm_provider == "cohere":
        cohere_key = (
             api_key
             or litellm.cohere_key
@@ -2528,6 +2528,28 @@ def embedding(
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
         )
+    elif custom_llm_provider == "vertex_ai":
+        vertex_ai_project = (
+            optional_params.pop("vertex_ai_project", None)
+            or litellm.vertex_project
+            or get_secret("VERTEXAI_PROJECT")
+        )
+        vertex_ai_location = (
+            optional_params.pop("vertex_ai_location", None)
+            or litellm.vertex_location
+            or get_secret("VERTEXAI_LOCATION")
+        )
+
+        response = vertex_ai.embedding(
+            model=model,
+            input=input,
+            encoding=encoding,
+            logging_obj=logging,
+            optional_params=optional_params,
+            model_response=EmbeddingResponse(),
+            vertex_project=vertex_ai_project,
+            vertex_location=vertex_ai_location,
+        )
     elif custom_llm_provider == "oobabooga":
         response = oobabooga.embedding(
             model=model,
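The project and location each resolve through a three-level fallback: per-call parameters, module-level settings, then environment secrets read via get_secret(). A sketch of the three options (project and region names are placeholders; whether per-call kwargs actually land in optional_params depends on how litellm forwards extra kwargs):

import litellm

# 1. Per-call, popped out of optional_params by the routing code above
litellm.embedding(
    model="textembedding-gecko@001",
    input=["hello"],
    vertex_ai_project="my-gcp-project",  # placeholder
    vertex_ai_location="us-central1",    # placeholder
)

# 2. Module-level settings
litellm.vertex_project = "my-gcp-project"
litellm.vertex_location = "us-central1"

# 3. Environment variables:
#    export VERTEXAI_PROJECT=my-gcp-project
#    export VERTEXAI_LOCATION=us-central1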

View file

@@ -231,6 +231,19 @@ def test_cohere_embedding3():
 # test_cohere_embedding3()


+def test_vertexai_embedding():
+    try:
+        # litellm.set_verbose=True
+        response = embedding(
+            model="textembedding-gecko@001",
+            input=["good morning from litellm", "this is another item"],
+        )
+        print(f"response: {response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_bedrock_embedding_titan():
     try:
         # this tests if we support str input for bedrock embedding

View file

@@ -4538,6 +4538,7 @@ def get_llm_provider(
             or model in litellm.vertex_text_models
             or model in litellm.vertex_code_text_models
             or model in litellm.vertex_language_models
+            or model in litellm.vertex_embedding_models
         ):
             custom_llm_provider = "vertex_ai"
         ## ai21
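With vertex_embedding_models included in this check, gecko model names now resolve to the vertex_ai provider without an explicit custom_llm_provider argument. A quick sanity check, assuming get_llm_provider's usual (model, provider, api_key, api_base) return tuple:

from litellm.utils import get_llm_provider

# "textembedding-gecko@001" is registered under vertex_ai-embedding-models,
# so it lands in litellm.vertex_embedding_models at import time.
model, provider, _, _ = get_llm_provider(model="textembedding-gecko@001")
assert provider == "vertex_ai"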

View file

@@ -612,6 +612,51 @@
         "litellm_provider": "vertex_ai-vision-models",
         "mode": "chat"
     },
+    "textembedding-gecko": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko-multilingual": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko-multilingual@001": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko@001": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko@003": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
     "palm/chat-bison": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000000125,