diff --git a/litellm/__init__.py b/litellm/__init__.py
index 0021daef0b..36fac67373 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -234,6 +234,7 @@ vertex_chat_models: List = []
 vertex_code_chat_models: List = []
 vertex_text_models: List = []
 vertex_code_text_models: List = []
+vertex_embedding_models: List = []
 ai21_models: List = []
 nlp_cloud_models: List = []
 aleph_alpha_models: List = []
@@ -263,6 +264,8 @@ for key, value in model_cost.items():
         vertex_chat_models.append(key)
     elif value.get("litellm_provider") == "vertex_ai-code-chat-models":
         vertex_code_chat_models.append(key)
+    elif value.get("litellm_provider") == "vertex_ai-embedding-models":
+        vertex_embedding_models.append(key)
     elif value.get("litellm_provider") == "ai21":
         ai21_models.append(key)
     elif value.get("litellm_provider") == "nlp_cloud":
@@ -499,7 +502,10 @@ bedrock_embedding_models: List = [
 ]
 
 all_embedding_models = (
-    open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
+    open_ai_embedding_models
+    + cohere_embedding_models
+    + bedrock_embedding_models
+    + vertex_embedding_models
 )
 
 ####### IMAGE GENERATION MODELS ###################
diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 9965c037a5..56f50c1190 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests
 import time
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
 import litellm, uuid
 import httpx
@@ -935,6 +935,68 @@ async def async_streaming(
     return streamwrapper
 
 
-def embedding():
+def embedding(
+    model: str,
+    input: Union[list, str],
+    api_key: Optional[str] = None,
+    logging_obj=None,
+    model_response=None,
+    optional_params=None,
+    encoding=None,
+    vertex_project=None,
+    vertex_location=None,
+):
     # logic for parsing in - calling - parsing out model embedding calls
-    pass
+    try:
+        import vertexai
+    except ImportError:
+        raise VertexAIError(
+            status_code=400,
+            message="vertexai import failed, please run `pip install google-cloud-aiplatform`",
+        )
+
+    from vertexai.language_models import TextEmbeddingModel
+    import google.auth
+
+    ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
+    try:
+        creds, _ = google.auth.default(quota_project_id=vertex_project)
+        vertexai.init(
+            project=vertex_project, location=vertex_location, credentials=creds
+        )
+    except Exception as e:
+        raise VertexAIError(status_code=401, message=str(e))
+
+    if isinstance(input, str):
+        input = [input]
+
+    try:
+        llm_model = TextEmbeddingModel.from_pretrained(model)
+        embeddings = llm_model.get_embeddings(input)
+    except Exception as e:
+        raise VertexAIError(status_code=500, message=str(e))
+    ## Populate OpenAI compliant dictionary
+    embedding_response = []
+    for idx, embedding in enumerate(embeddings):
+        embedding_response.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding.values,
+            }
+        )
+    model_response["object"] = "list"
+    model_response["data"] = embedding_response
+    model_response["model"] = model
+    input_tokens = 0
+
+    input_str = "".join(input)
+
+    input_tokens += len(encoding.encode(input_str))
+
+    usage = Usage(
+        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+    )
+    model_response.usage = usage
+
+    return model_response
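Reviewer note: for context, here is a minimal standalone sketch of the Vertex AI SDK calls the new handler wraps. It assumes `google-cloud-aiplatform` is installed and Application Default Credentials are configured; the project id and location are placeholders. The handler above adds credential loading with a quota project, error mapping to `VertexAIError`, and the OpenAI-shaped response on top of these calls.

```python
# Minimal sketch of the SDK calls wrapped by vertex_ai.embedding().
# "my-gcp-project" / "us-central1" are placeholder values.
import vertexai
from vertexai.language_models import TextEmbeddingModel

vertexai.init(project="my-gcp-project", location="us-central1")

model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")
embeddings = model.get_embeddings(["good morning from litellm"])
print(len(embeddings[0].values))  # gecko models return 768-dim vectors
```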
diff --git a/litellm/main.py b/litellm/main.py
index a30d6a8e41..9a24db87ab 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2486,7 +2486,7 @@ def embedding(
             client=client,
             aembedding=aembedding,
         )
-    elif model in litellm.cohere_embedding_models:
+    elif custom_llm_provider == "cohere":
         cohere_key = (
             api_key
             or litellm.cohere_key
@@ -2528,6 +2528,28 @@
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
         )
+    elif custom_llm_provider == "vertex_ai":
+        vertex_ai_project = (
+            optional_params.pop("vertex_ai_project", None)
+            or litellm.vertex_project
+            or get_secret("VERTEXAI_PROJECT")
+        )
+        vertex_ai_location = (
+            optional_params.pop("vertex_ai_location", None)
+            or litellm.vertex_location
+            or get_secret("VERTEXAI_LOCATION")
+        )
+
+        response = vertex_ai.embedding(
+            model=model,
+            input=input,
+            encoding=encoding,
+            logging_obj=logging,
+            optional_params=optional_params,
+            model_response=EmbeddingResponse(),
+            vertex_project=vertex_ai_project,
+            vertex_location=vertex_ai_location,
+        )
     elif custom_llm_provider == "oobabooga":
         response = oobabooga.embedding(
             model=model,
diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py
index 681471e3db..9b3b34bfef 100644
--- a/litellm/tests/test_embedding.py
+++ b/litellm/tests/test_embedding.py
@@ -231,6 +231,18 @@ def test_cohere_embedding3():
 # test_cohere_embedding3()
 
 
+def test_vertexai_embedding():
+    try:
+        # litellm.set_verbose=True
+        response = embedding(
+            model="textembedding-gecko@001",
+            input=["good morning from litellm", "this is another item"],
+        )
+        print(f"response: {response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_bedrock_embedding_titan():
     try:
         # this tests if we support str input for bedrock embedding
diff --git a/litellm/utils.py b/litellm/utils.py
index f265a01906..2bab08876a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4538,6 +4538,7 @@ def get_llm_provider(
             or model in litellm.vertex_text_models
             or model in litellm.vertex_code_text_models
             or model in litellm.vertex_language_models
+            or model in litellm.vertex_embedding_models
         ):
             custom_llm_provider = "vertex_ai"
         ## ai21
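With the routing in `main.py` and the `get_llm_provider()` change in place, the path exercised by the new test can also be called directly. A hedged usage sketch follows: the project and location values are placeholders, and it assumes extra kwargs reach `optional_params` as they do for the other providers here (alternatively, both values can come from `litellm.vertex_project`/`litellm.vertex_location` or the `VERTEXAI_PROJECT`/`VERTEXAI_LOCATION` secrets).

```python
import litellm

# Provider is inferred from the model name via get_llm_provider(),
# which now also scans litellm.vertex_embedding_models.
response = litellm.embedding(
    model="textembedding-gecko@001",
    input=["good morning from litellm", "this is another item"],
    vertex_ai_project="my-gcp-project",  # placeholder; popped from optional_params
    vertex_ai_location="us-central1",    # placeholder
)

print(response.data[0]["embedding"][:5])  # first values of the 768-dim vector
print(response.usage)  # completion_tokens == 0; prompt_tokens == total_tokens
```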
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index 3f92c61d2c..7c2b1f6bb1 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -612,6 +612,51 @@
         "litellm_provider": "vertex_ai-vision-models",
         "mode": "chat"
     },
+    "textembedding-gecko": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko-multilingual": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko-multilingual@001": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko@001": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
+    "textembedding-gecko@003": {
+        "max_tokens": 3072,
+        "max_input_tokens": 3072,
+        "output_vector_size": 768,
+        "input_cost_per_token": 0.00000000625,
+        "output_cost_per_token": 0,
+        "litellm_provider": "vertex_ai-embedding-models",
+        "mode": "embedding"
+    },
     "palm/chat-bison": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000000125,
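These pricing entries flow into `litellm.model_cost` (built from this JSON), which is what the `__init__.py` loop above scans to populate `vertex_embedding_models`. A quick sanity check, assuming the entry names above and that the package has loaded this version of the JSON:

```python
import litellm

info = litellm.model_cost["textembedding-gecko@001"]
assert info["litellm_provider"] == "vertex_ai-embedding-models"
assert "textembedding-gecko@001" in litellm.vertex_embedding_models

# Embeddings bill input tokens only (output_cost_per_token is 0):
# 1,000 tokens * $0.00000000625/token = $0.00000625
print(1000 * info["input_cost_per_token"])
```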