mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
feat(vertex_ai.py): vertex ai gecko text embedding support
parent 6cdb9aede0
commit d9ba8668f4
6 changed files with 154 additions and 5 deletions

@@ -234,6 +234,7 @@ vertex_chat_models: List = []
 vertex_code_chat_models: List = []
 vertex_text_models: List = []
 vertex_code_text_models: List = []
+vertex_embedding_models: List = []
 ai21_models: List = []
 nlp_cloud_models: List = []
 aleph_alpha_models: List = []
@@ -263,6 +264,8 @@ for key, value in model_cost.items():
         vertex_chat_models.append(key)
     elif value.get("litellm_provider") == "vertex_ai-code-chat-models":
         vertex_code_chat_models.append(key)
+    elif value.get("litellm_provider") == "vertex_ai-embedding-models":
+        vertex_embedding_models.append(key)
     elif value.get("litellm_provider") == "ai21":
         ai21_models.append(key)
     elif value.get("litellm_provider") == "nlp_cloud":
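
The loop above buckets every entry of litellm's model-cost map (presumably loaded from model_prices_and_context_window.json) into per-provider lists, so the new "vertex_ai-embedding-models" branch is what makes the gecko models discoverable. A minimal sketch of that classification, with a hand-written stand-in for model_cost:

# Sketch only: a stripped-down version of the classification loop above,
# assuming model_cost maps model names to metadata dicts.
model_cost = {
    "textembedding-gecko@001": {
        "litellm_provider": "vertex_ai-embedding-models",
        "mode": "embedding",
    },
}

vertex_embedding_models = []
for key, value in model_cost.items():
    if value.get("litellm_provider") == "vertex_ai-embedding-models":
        vertex_embedding_models.append(key)

print(vertex_embedding_models)  # ['textembedding-gecko@001']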
@@ -499,7 +502,10 @@ bedrock_embedding_models: List = [
 ]
 
 all_embedding_models = (
-    open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
+    open_ai_embedding_models
+    + cohere_embedding_models
+    + bedrock_embedding_models
+    + vertex_embedding_models
 )
 
 ####### IMAGE GENERATION MODELS ###################
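
With vertex_embedding_models folded into all_embedding_models, membership checks elsewhere in the library pick the gecko models up automatically. A small illustration (list contents are illustrative, not the full lists):

# Illustration only: how the combined list is assembled from per-provider lists.
open_ai_embedding_models = ["text-embedding-ada-002"]
cohere_embedding_models = ["embed-english-v3.0"]
bedrock_embedding_models = ["amazon.titan-embed-text-v1"]
vertex_embedding_models = ["textembedding-gecko", "textembedding-gecko@001"]

all_embedding_models = (
    open_ai_embedding_models
    + cohere_embedding_models
    + bedrock_embedding_models
    + vertex_embedding_models
)

assert "textembedding-gecko@001" in all_embedding_models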
@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests
 import time
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
 import litellm, uuid
 import httpx
@@ -935,6 +935,68 @@ async def async_streaming(
     return streamwrapper
 
 
-def embedding():
+def embedding(
+    model: str,
+    input: Union[list, str],
+    api_key: Optional[str] = None,
+    logging_obj=None,
+    model_response=None,
+    optional_params=None,
+    encoding=None,
+    vertex_project=None,
+    vertex_location=None,
+):
     # logic for parsing in - calling - parsing out model embedding calls
-    pass
+    try:
+        import vertexai
+    except:
+        raise VertexAIError(
+            status_code=400,
+            message="vertexai import failed please run `pip install google-cloud-aiplatform`",
+        )
+
+    from vertexai.language_models import TextEmbeddingModel
+    import google.auth
+
+    ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
+    try:
+        creds, _ = google.auth.default(quota_project_id=vertex_project)
+        vertexai.init(
+            project=vertex_project, location=vertex_location, credentials=creds
+        )
+    except Exception as e:
+        raise VertexAIError(status_code=401, message=str(e))
+
+    if isinstance(input, str):
+        input = [input]
+
+    try:
+        llm_model = TextEmbeddingModel.from_pretrained(model)
+        embeddings = llm_model.get_embeddings(input)
+    except Exception as e:
+        raise VertexAIError(status_code=500, message=str(e))
+
+    ## Populate OpenAI compliant dictionary
+    embedding_response = []
+    for idx, embedding in enumerate(embeddings):
+        embedding_response.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding.values,
+            }
+        )
+    model_response["object"] = "list"
+    model_response["data"] = embedding_response
+    model_response["model"] = model
+    input_tokens = 0
+
+    input_str = "".join(input)
+
+    input_tokens += len(encoding.encode(input_str))
+
+    usage = Usage(
+        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+    )
+    model_response.usage = usage
+
+    return model_response
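
The new helper assembles an OpenAI-compatible embedding payload from the Vertex AI SDK response. A sketch of the resulting shape (the values below are illustrative; real gecko vectors are 768-dimensional, per the output_vector_size entries added further down):

# Illustrative shape only; each "embedding" list would really hold 768 floats.
example_response = {
    "object": "list",
    "data": [
        {"object": "embedding", "index": 0, "embedding": [0.012, -0.034, 0.056]},
        {"object": "embedding", "index": 1, "embedding": [0.021, 0.003, -0.047]},
    ],
    "model": "textembedding-gecko@001",
    "usage": {"prompt_tokens": 9, "completion_tokens": 0, "total_tokens": 9},
}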
@@ -2486,7 +2486,7 @@ def embedding(
             client=client,
             aembedding=aembedding,
         )
-    elif model in litellm.cohere_embedding_models:
+    elif custom_llm_provider == "cohere":
         cohere_key = (
             api_key
             or litellm.cohere_key
@@ -2528,6 +2528,28 @@ def embedding(
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
         )
+    elif custom_llm_provider == "vertex_ai":
+        vertex_ai_project = (
+            optional_params.pop("vertex_ai_project", None)
+            or litellm.vertex_project
+            or get_secret("VERTEXAI_PROJECT")
+        )
+        vertex_ai_location = (
+            optional_params.pop("vertex_ai_location", None)
+            or litellm.vertex_location
+            or get_secret("VERTEXAI_LOCATION")
+        )
+
+        response = vertex_ai.embedding(
+            model=model,
+            input=input,
+            encoding=encoding,
+            logging_obj=logging,
+            optional_params=optional_params,
+            model_response=EmbeddingResponse(),
+            vertex_project=vertex_ai_project,
+            vertex_location=vertex_ai_location,
+        )
     elif custom_llm_provider == "oobabooga":
         response = oobabooga.embedding(
             model=model,
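
With this branch in the top-level embedding router, calls can reach Vertex AI by provider. A usage sketch, assuming Google application-default credentials are set up and using a hypothetical project id; per the branch above, project and location could equally come from the VERTEXAI_PROJECT / VERTEXAI_LOCATION secrets:

# Sketch: routing an embedding call to Vertex AI through litellm.
# Assumes `gcloud auth application-default login` (or equivalent) has been run.
import litellm

litellm.vertex_project = "my-gcp-project"   # hypothetical project id
litellm.vertex_location = "us-central1"

response = litellm.embedding(
    model="textembedding-gecko@001",
    input=["good morning from litellm", "this is another item"],
)
print(len(response["data"]), "embeddings returned")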
@@ -231,6 +231,19 @@ def test_cohere_embedding3():
 # test_cohere_embedding3()
 
 
+def test_vertexai_embedding():
+    try:
+        # litellm.set_verbose=True
+        response = embedding(
+            model="textembedding-gecko@001",
+            input=["good morning from litellm", "this is another item"],
+        )
+        print(f"response:", response)
+        raise Exception("it worked!")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_bedrock_embedding_titan():
     try:
         # this tests if we support str input for bedrock embedding
@@ -4538,6 +4538,7 @@ def get_llm_provider(
             or model in litellm.vertex_text_models
             or model in litellm.vertex_code_text_models
             or model in litellm.vertex_language_models
+            or model in litellm.vertex_embedding_models
         ):
             custom_llm_provider = "vertex_ai"
         ## ai21
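
The extra membership check is what lets provider resolution classify the gecko models without any "vertex_ai/" prefix. A hedged sketch of the expected behaviour, assuming the four-value return shape used by get_llm_provider in this part of the codebase:

# Sketch: provider resolution for a Vertex AI embedding model.
from litellm.utils import get_llm_provider

model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
    model="textembedding-gecko@001"
)
print(custom_llm_provider)  # expected: "vertex_ai"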
@@ -612,6 +612,51 @@
         "litellm_provider": "vertex_ai-vision-models",
         "mode": "chat"
     },
+    "textembedding-gecko": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko-multilingual": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko-multilingual@001": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko@001": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko@003": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
     "palm/chat-bison": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000000125,
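
These pricing entries are what allow embedding calls to be costed: input tokens are billed at 0.00000000625 per token and output costs nothing. A worked example using the value registered above:

# Worked example: cost of a 1,000-token gecko embedding request,
# using the input_cost_per_token value from the JSON entries above.
input_cost_per_token = 0.00000000625
prompt_tokens = 1000

cost = prompt_tokens * input_cost_per_token
print(f"${cost:.8f}")  # $0.00000625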