fix: initial commit

This commit is contained in:
Krrish Dholakia 2024-08-27 17:35:56 -07:00
parent d29a7087f1
commit 77e6da78a1
11 changed files with 192 additions and 36 deletions

View file

@ -848,7 +848,7 @@ from .llms.gemini import GeminiConfig
from .llms.nlp_cloud import NLPCloudConfig from .llms.nlp_cloud import NLPCloudConfig
from .llms.aleph_alpha import AlephAlphaConfig from .llms.aleph_alpha import AlephAlphaConfig
from .llms.petals import PetalsConfig from .llms.petals import PetalsConfig
from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig, VertexGeminiConfig,
GoogleAIStudioGeminiConfig, GoogleAIStudioGeminiConfig,
VertexAIConfig, VertexAIConfig,
@ -862,9 +862,6 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import ( from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import (
VertexAILlama3Config, VertexAILlama3Config,
) )
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transformation import (
VertexAIAi21Config,
)
from .llms.sagemaker.sagemaker import SagemakerConfig from .llms.sagemaker.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig from .llms.ollama_chat import OllamaChatConfig

View file

@ -8,7 +8,7 @@ from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparamet
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.llms.base import BaseLLM from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM, VertexLLM,
) )
from litellm.types.llms.openai import FineTuningJobCreate from litellm.types.llms.openai import FineTuningJobCreate

View file

@ -13,7 +13,7 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client, _get_httpx_client,
) )
from litellm.llms.openai import HttpxBinaryResponseContent from litellm.llms.openai import HttpxBinaryResponseContent
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM, VertexLLM,
) )

View file

@ -69,6 +69,9 @@ def _get_vertex_url(
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}" url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if stream is True: if stream is True:
url += "?alt=sse" url += "?alt=sse"
elif mode == "embedding":
endpoint = "predict"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
return url, endpoint return url, endpoint
@ -79,8 +82,8 @@ def _get_gemini_url(
stream: Optional[bool], stream: Optional[bool],
gemini_api_key: Optional[str], gemini_api_key: Optional[str],
) -> Tuple[str, str]: ) -> Tuple[str, str]:
_gemini_model_name = "models/{}".format(model)
if mode == "chat": if mode == "chat":
_gemini_model_name = "models/{}".format(model)
endpoint = "generateContent" endpoint = "generateContent"
if stream is True: if stream is True:
endpoint = "streamGenerateContent" endpoint = "streamGenerateContent"
@ -94,5 +97,8 @@ def _get_gemini_url(
) )
) )
elif mode == "embedding": elif mode == "embedding":
pass endpoint = "embedContent"
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key
)
return url, endpoint return url, endpoint

View file

@ -11,8 +11,10 @@ from litellm.types.llms.vertex_ai import CachedContentRequestBody, SystemInstruc
from litellm.utils import is_cached_message from litellm.utils import is_cached_message
from ..common_utils import VertexAIError, get_supports_system_message from ..common_utils import VertexAIError, get_supports_system_message
from ..gemini_transformation import transform_system_message from ..gemini.transformation import transform_system_message
from ..vertex_and_google_ai_studio_gemini import _gemini_convert_messages_with_history from ..gemini.vertex_and_google_ai_studio_gemini import (
_gemini_convert_messages_with_history,
)
def separate_cached_messages( def separate_cached_messages(

View file

@ -0,0 +1,121 @@
"""
Google AI Studio Embeddings Endpoint
"""
import json
from typing import Literal, Optional, Union
import httpx
import litellm
from litellm import EmbeddingResponse
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from .vertex_and_google_ai_studio_gemini import VertexLLM
class GoogleEmbeddings(VertexLLM):
def text_embeddings(
self,
model: str,
input: Union[list, str],
print_verbose,
model_response: EmbeddingResponse,
custom_llm_provider: Literal["gemini", "vertex_ai"],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
logging_obj=None,
encoding=None,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
aembedding=False,
timeout=300,
client=None,
) -> EmbeddingResponse:
return model_response
auth_header, url = self._get_token_and_url(
model=model,
gemini_api_key=api_key,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=None,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
should_use_v1beta1_features=False,
mode="embedding",
)
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
_httpx_timeout = httpx.Timeout(timeout)
_params["timeout"] = _httpx_timeout
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore
else:
sync_handler = client # type: ignore
optional_params = optional_params or {}
# request_data = VertexMultimodalEmbeddingRequest()
# if "instances" in optional_params:
# request_data["instances"] = optional_params["instances"]
# elif isinstance(input, list):
# request_data["instances"] = input
# else:
# # construct instances
# vertex_request_instance = Instance(**optional_params)
# if isinstance(input, str):
# vertex_request_instance["text"] = input
# request_data["instances"] = [vertex_request_instance]
# headers = {
# "Content-Type": "application/json; charset=utf-8",
# "Authorization": f"Bearer {auth_header}",
# }
# ## LOGGING
# logging_obj.pre_call(
# input=input,
# api_key="",
# additional_args={
# "complete_input_dict": request_data,
# "api_base": url,
# "headers": headers,
# },
# )
# if aembedding is True:
# pass
# response = sync_handler.post(
# url=url,
# headers=headers,
# data=json.dumps(request_data),
# )
# if response.status_code != 200:
# raise Exception(f"Error: {response.status_code} {response.text}")
# _json_response = response.json()
# if "predictions" not in _json_response:
# raise litellm.InternalServerError(
# message=f"embedding response does not contain 'predictions', got {_json_response}",
# llm_provider="vertex_ai",
# model=model,
# )
# _predictions = _json_response["predictions"]
# model_response.data = _predictions
# model_response.model = model
# return model_response

View file

@ -0,0 +1,5 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /embedContent format.
Why separate file? Make it easy to see how transformation works
"""

View file

@ -1813,6 +1813,7 @@ class VertexLLM(BaseLLM):
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
api_base=api_base, api_base=api_base,
should_use_v1beta1_features=False, should_use_v1beta1_features=False,
mode="embedding",
) )
if client is None: if client is None:
@ -1828,11 +1829,6 @@ class VertexLLM(BaseLLM):
else: else:
sync_handler = client # type: ignore sync_handler = client # type: ignore
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
optional_params = optional_params or {} optional_params = optional_params or {}
request_data = VertexMultimodalEmbeddingRequest() request_data = VertexMultimodalEmbeddingRequest()
@ -1850,30 +1846,22 @@ class VertexLLM(BaseLLM):
request_data["instances"] = [vertex_request_instance] request_data["instances"] = [vertex_request_instance]
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
logging_obj.pre_call(
input=[],
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
logging_obj.pre_call(
input=[],
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
headers = { headers = {
"Content-Type": "application/json; charset=utf-8", "Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}", "Authorization": f"Bearer {auth_header}",
} }
## LOGGING
logging_obj.pre_call(
input=input,
api_key="",
additional_args={
"complete_input_dict": request_data,
"api_base": url,
"headers": headers,
},
)
if aembedding is True: if aembedding is True:
return self.async_multimodal_embedding( return self.async_multimodal_embedding(
model=model, model=model,

View file

@ -205,7 +205,7 @@ def get_vertex_client(
vertex_credentials: Optional[str], vertex_credentials: Optional[str],
) -> Tuple[Any, Optional[str]]: ) -> Tuple[Any, Optional[str]]:
args = locals() args = locals()
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM, VertexLLM,
) )
@ -270,7 +270,7 @@ def completion(
from anthropic import AnthropicVertex from anthropic import AnthropicVertex
from litellm.llms.anthropic import AnthropicChatCompletion from litellm.llms.anthropic import AnthropicChatCompletion
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM, VertexLLM,
) )
except: except:

View file

@ -3134,6 +3134,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
or custom_llm_provider == "fireworks_ai" or custom_llm_provider == "fireworks_ai"
or custom_llm_provider == "ollama" or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "gemini"
or custom_llm_provider == "databricks" or custom_llm_provider == "databricks"
or custom_llm_provider == "watsonx" or custom_llm_provider == "watsonx"
or custom_llm_provider == "cohere" or custom_llm_provider == "cohere"
@ -3528,6 +3529,26 @@ def embedding(
client=client, client=client,
aembedding=aembedding, aembedding=aembedding,
) )
elif custom_llm_provider == "gemini":
gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
response = vertex_chat_completion.multimodal_embedding( # type: ignore
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
aembedding=aembedding,
print_verbose=print_verbose,
custom_llm_provider="gemini",
api_key=gemini_api_key,
)
elif custom_llm_provider == "vertex_ai": elif custom_llm_provider == "vertex_ai":
vertex_ai_project = ( vertex_ai_project = (
optional_params.pop("vertex_project", None) optional_params.pop("vertex_project", None)

View file

@ -686,6 +686,22 @@ async def test_triton_embeddings():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_gemini_embeddings():
try:
litellm.set_verbose = True
response = await litellm.aembedding(
model="gemini/text-embedding-004",
input=["good morning from litellm"],
)
print(f"response: {response}")
# stubbed endpoint is setup to return this
assert response.data[0]["embedding"] == [0.1, 0.2]
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_databricks_embeddings(sync_mode): async def test_databricks_embeddings(sync_mode):