diff --git a/litellm/__init__.py b/litellm/__init__.py index a627061cfe..591d5873cf 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -848,7 +848,7 @@ from .llms.gemini import GeminiConfig from .llms.nlp_cloud import NLPCloudConfig from .llms.aleph_alpha import AlephAlphaConfig from .llms.petals import PetalsConfig -from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( +from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexGeminiConfig, GoogleAIStudioGeminiConfig, VertexAIConfig, @@ -862,9 +862,6 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import ( from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import ( VertexAILlama3Config, ) -from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transformation import ( - VertexAIAi21Config, -) from .llms.sagemaker.sagemaker import SagemakerConfig from .llms.ollama import OllamaConfig from .llms.ollama_chat import OllamaChatConfig diff --git a/litellm/llms/fine_tuning_apis/vertex_ai.py b/litellm/llms/fine_tuning_apis/vertex_ai.py index e87a9bf3c4..618cf510af 100644 --- a/litellm/llms/fine_tuning_apis/vertex_ai.py +++ b/litellm/llms/fine_tuning_apis/vertex_ai.py @@ -8,7 +8,7 @@ from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparamet from litellm._logging import verbose_logger from litellm.llms.base import BaseLLM from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( +from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) from litellm.types.llms.openai import FineTuningJobCreate diff --git a/litellm/llms/text_to_speech/vertex_ai.py b/litellm/llms/text_to_speech/vertex_ai.py index b9fca53250..0aac32eb50 100644 --- a/litellm/llms/text_to_speech/vertex_ai.py +++ b/litellm/llms/text_to_speech/vertex_ai.py @@ -13,7 +13,7 @@ from litellm.llms.custom_httpx.http_handler import ( _get_httpx_client, ) from litellm.llms.openai import HttpxBinaryResponseContent -from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( +from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py index 7e2f9b29d0..d8607b4a8a 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py @@ -69,6 +69,9 @@ def _get_vertex_url( url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}" if stream is True: url += "?alt=sse" + elif mode == "embedding": + endpoint = "predict" + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" return url, endpoint @@ -79,8 +82,8 @@ def _get_gemini_url( stream: Optional[bool], gemini_api_key: Optional[str], ) -> Tuple[str, str]: + _gemini_model_name = "models/{}".format(model) if mode == "chat": - _gemini_model_name = "models/{}".format(model) endpoint = "generateContent" if stream is True: endpoint = "streamGenerateContent" @@ -94,5 +97,8 @@ def _get_gemini_url( ) ) elif mode == "embedding": - pass + endpoint = "embedContent" + url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format( + _gemini_model_name, endpoint, gemini_api_key + ) return url, endpoint diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py index 944ae00bc9..d394cafd3a 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py @@ -11,8 +11,10 @@ from litellm.types.llms.vertex_ai import CachedContentRequestBody, SystemInstruc from litellm.utils import is_cached_message from ..common_utils import VertexAIError, get_supports_system_message -from ..gemini_transformation import transform_system_message -from ..vertex_and_google_ai_studio_gemini import _gemini_convert_messages_with_history +from ..gemini.transformation import transform_system_message +from ..gemini.vertex_and_google_ai_studio_gemini import ( + _gemini_convert_messages_with_history, +) def separate_cached_messages( diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_handler.py new file mode 100644 index 0000000000..bc0d4ac16f --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_handler.py @@ -0,0 +1,121 @@ +""" +Google AI Studio Embeddings Endpoint +""" + +import json +from typing import Literal, Optional, Union + +import httpx + +import litellm +from litellm import EmbeddingResponse +from litellm.llms.custom_httpx.http_handler import HTTPHandler + +from .vertex_and_google_ai_studio_gemini import VertexLLM + + +class GoogleEmbeddings(VertexLLM): + def text_embeddings( + self, + model: str, + input: Union[list, str], + print_verbose, + model_response: EmbeddingResponse, + custom_llm_provider: Literal["gemini", "vertex_ai"], + optional_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + logging_obj=None, + encoding=None, + vertex_project=None, + vertex_location=None, + vertex_credentials=None, + aembedding=False, + timeout=300, + client=None, + ) -> EmbeddingResponse: + return model_response + auth_header, url = self._get_token_and_url( + model=model, + gemini_api_key=api_key, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + stream=None, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + should_use_v1beta1_features=False, + mode="embedding", + ) + + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + _httpx_timeout = httpx.Timeout(timeout) + _params["timeout"] = _httpx_timeout + else: + _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) + + sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore + else: + sync_handler = client # type: ignore + + optional_params = optional_params or {} + + # request_data = VertexMultimodalEmbeddingRequest() + + # if "instances" in optional_params: + # request_data["instances"] = optional_params["instances"] + # elif isinstance(input, list): + # request_data["instances"] = input + # else: + # # construct instances + # vertex_request_instance = Instance(**optional_params) + + # if isinstance(input, str): + # vertex_request_instance["text"] = input + + # request_data["instances"] = [vertex_request_instance] + + # headers = { + # "Content-Type": "application/json; charset=utf-8", + # "Authorization": f"Bearer {auth_header}", + # } + + # ## LOGGING + # logging_obj.pre_call( + # input=input, + # api_key="", + # additional_args={ + # "complete_input_dict": request_data, + # "api_base": url, + # "headers": headers, + # }, + # ) + + # if aembedding is True: + # pass + + # response = sync_handler.post( + # url=url, + # headers=headers, + # data=json.dumps(request_data), + # ) + + # if response.status_code != 200: + # raise Exception(f"Error: {response.status_code} {response.text}") + + # _json_response = response.json() + # if "predictions" not in _json_response: + # raise litellm.InternalServerError( + # message=f"embedding response does not contain 'predictions', got {_json_response}", + # llm_provider="vertex_ai", + # model=model, + # ) + # _predictions = _json_response["predictions"] + + # model_response.data = _predictions + # model_response.model = model + + # return model_response diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_transformation.py new file mode 100644 index 0000000000..2e3d156f51 --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_transformation.py @@ -0,0 +1,5 @@ +""" +Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /embedContent format. + +Why separate file? Make it easy to see how transformation works +""" diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py index d897f5bfbd..819a94cb0c 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py @@ -1813,6 +1813,7 @@ class VertexLLM(BaseLLM): custom_llm_provider=custom_llm_provider, api_base=api_base, should_use_v1beta1_features=False, + mode="embedding", ) if client is None: @@ -1828,11 +1829,6 @@ class VertexLLM(BaseLLM): else: sync_handler = client # type: ignore - url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict" - - auth_header, _ = self._ensure_access_token( - credentials=vertex_credentials, project_id=vertex_project - ) optional_params = optional_params or {} request_data = VertexMultimodalEmbeddingRequest() @@ -1850,30 +1846,22 @@ class VertexLLM(BaseLLM): request_data["instances"] = [vertex_request_instance] - request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\"" - logging_obj.pre_call( - input=[], - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - logging_obj.pre_call( - input=[], - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - headers = { "Content-Type": "application/json; charset=utf-8", "Authorization": f"Bearer {auth_header}", } + ## LOGGING + logging_obj.pre_call( + input=input, + api_key="", + additional_args={ + "complete_input_dict": request_data, + "api_base": url, + "headers": headers, + }, + ) + if aembedding is True: return self.async_multimodal_embedding( model=model, diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py index b13b87bc67..e85160a43c 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py @@ -205,7 +205,7 @@ def get_vertex_client( vertex_credentials: Optional[str], ) -> Tuple[Any, Optional[str]]: args = locals() - from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( + from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) @@ -270,7 +270,7 @@ def completion( from anthropic import AnthropicVertex from litellm.llms.anthropic import AnthropicChatCompletion - from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( + from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) except: diff --git a/litellm/main.py b/litellm/main.py index b83a583f4a..8896f1faf1 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3134,6 +3134,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse: or custom_llm_provider == "fireworks_ai" or custom_llm_provider == "ollama" or custom_llm_provider == "vertex_ai" + or custom_llm_provider == "gemini" or custom_llm_provider == "databricks" or custom_llm_provider == "watsonx" or custom_llm_provider == "cohere" @@ -3528,6 +3529,26 @@ def embedding( client=client, aembedding=aembedding, ) + elif custom_llm_provider == "gemini": + + gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key + + response = vertex_chat_completion.multimodal_embedding( # type: ignore + model=model, + input=input, + encoding=encoding, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + vertex_project=None, + vertex_location=None, + vertex_credentials=None, + aembedding=aembedding, + print_verbose=print_verbose, + custom_llm_provider="gemini", + api_key=gemini_api_key, + ) + elif custom_llm_provider == "vertex_ai": vertex_ai_project = ( optional_params.pop("vertex_project", None) diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index 31268395f1..2fbc70f024 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -686,6 +686,22 @@ async def test_triton_embeddings(): pytest.fail(f"Error occurred: {e}") +@pytest.mark.asyncio +async def test_gemini_embeddings(): + try: + litellm.set_verbose = True + response = await litellm.aembedding( + model="gemini/text-embedding-004", + input=["good morning from litellm"], + ) + print(f"response: {response}") + + # stubbed endpoint is setup to return this + assert response.data[0]["embedding"] == [0.1, 0.2] + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + @pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.asyncio async def test_databricks_embeddings(sync_mode):