Mirror of https://github.com/BerriAI/litellm.git

Commit 77e6da78a1 ("fix: initial commit"), parent d29a7087f1
11 changed files with 192 additions and 36 deletions

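In brief: this commit wires Google AI Studio ("gemini") embeddings into litellm. The vertex/gemini URL helpers gain an embedding mode, a new GoogleEmbeddings handler and a transformation stub are added, embedding()/aembedding() learn to route gemini models, a test is added, and import paths change because vertex_and_google_ai_studio_gemini now lives in a gemini/ subfolder.
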
@@ -848,7 +848,7 @@ from .llms.gemini import GeminiConfig
 from .llms.nlp_cloud import NLPCloudConfig
 from .llms.aleph_alpha import AlephAlphaConfig
 from .llms.petals import PetalsConfig
-from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
     GoogleAIStudioGeminiConfig,
     VertexAIConfig,
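
The path rename visible here (vertex_and_google_ai_studio_gemini moving into a new gemini/ subpackage) recurs across most hunks below; call sites change only their import path. A minimal before/after sketch (the VertexLLM symbol is taken from later hunks):

# Old module path (pre-commit):
#   litellm/llms/vertex_ai_and_google_ai_studio/vertex_and_google_ai_studio_gemini.py
# New module path (this commit onward):
#   litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py

from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexLLM,
)
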
@@ -862,9 +862,6 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import (
    VertexAILlama3Config,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transformation import (
    VertexAIAi21Config,
)
from .llms.sagemaker.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig
@@ -8,7 +8,7 @@ from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparameters
 from litellm._logging import verbose_logger
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
 from litellm.types.llms.openai import FineTuningJobCreate
@@ -13,7 +13,7 @@ from litellm.llms.custom_httpx.http_handler import (
     _get_httpx_client,
 )
 from litellm.llms.openai import HttpxBinaryResponseContent
-from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
 
@@ -69,6 +69,9 @@ def _get_vertex_url(
         url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
         if stream is True:
             url += "?alt=sse"
+    elif mode == "embedding":
+        endpoint = "predict"
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
 
     return url, endpoint
 
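
With mode="embedding", _get_vertex_url now targets the publisher model's :predict endpoint. A rough sketch of the URL this branch builds, with invented project and location values:

# Hypothetical example values; only the URL shape comes from the diff.
vertex_project = "my-project"
vertex_location = "us-central1"
model = "multimodalembedding@001"

url = (
    f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}"
    f"/locations/{vertex_location}/publishers/google/models/{model}:predict"
)
print(url)
# https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/multimodalembedding@001:predict
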
@@ -79,8 +82,8 @@ def _get_gemini_url(
     stream: Optional[bool],
     gemini_api_key: Optional[str],
 ) -> Tuple[str, str]:
+    _gemini_model_name = "models/{}".format(model)
     if mode == "chat":
-        _gemini_model_name = "models/{}".format(model)
         endpoint = "generateContent"
         if stream is True:
             endpoint = "streamGenerateContent"
@@ -94,5 +97,8 @@ def _get_gemini_url(
             )
         )
     elif mode == "embedding":
-        pass
+        endpoint = "embedContent"
+        url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+            _gemini_model_name, endpoint, gemini_api_key
+        )
     return url, endpoint
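
For Google AI Studio, _get_gemini_url now routes embeddings to the model's :embedContent method, with the API key passed as a query parameter. Roughly, for a text-embedding-004 model (the key value here is a placeholder):

_gemini_model_name = "models/text-embedding-004"
endpoint = "embedContent"
gemini_api_key = "PLACEHOLDER_KEY"  # assumed value; never hard-code real keys

url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
    _gemini_model_name, endpoint, gemini_api_key
)
print(url)
# https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent?key=PLACEHOLDER_KEY
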
@@ -11,8 +11,10 @@ from litellm.types.llms.vertex_ai import CachedContentRequestBody, SystemInstructions
 from litellm.utils import is_cached_message
 
 from ..common_utils import VertexAIError, get_supports_system_message
-from ..gemini_transformation import transform_system_message
-from ..vertex_and_google_ai_studio_gemini import _gemini_convert_messages_with_history
+from ..gemini.transformation import transform_system_message
+from ..gemini.vertex_and_google_ai_studio_gemini import (
+    _gemini_convert_messages_with_history,
+)
 
 
 def separate_cached_messages(
@@ -0,0 +1,121 @@
"""
Google AI Studio Embeddings Endpoint
"""

import json
from typing import Literal, Optional, Union

import httpx

import litellm
from litellm import EmbeddingResponse
from litellm.llms.custom_httpx.http_handler import HTTPHandler

from .vertex_and_google_ai_studio_gemini import VertexLLM


class GoogleEmbeddings(VertexLLM):
    def text_embeddings(
        self,
        model: str,
        input: Union[list, str],
        print_verbose,
        model_response: EmbeddingResponse,
        custom_llm_provider: Literal["gemini", "vertex_ai"],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        logging_obj=None,
        encoding=None,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        aembedding=False,
        timeout=300,
        client=None,
    ) -> EmbeddingResponse:
        return model_response
        auth_header, url = self._get_token_and_url(
            model=model,
            gemini_api_key=api_key,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=None,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode="embedding",
        )

        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
            else:
                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
        else:
            sync_handler = client  # type: ignore

        optional_params = optional_params or {}

        # request_data = VertexMultimodalEmbeddingRequest()

        # if "instances" in optional_params:
        #     request_data["instances"] = optional_params["instances"]
        # elif isinstance(input, list):
        #     request_data["instances"] = input
        # else:
        #     # construct instances
        #     vertex_request_instance = Instance(**optional_params)

        #     if isinstance(input, str):
        #         vertex_request_instance["text"] = input

        #     request_data["instances"] = [vertex_request_instance]

        # headers = {
        #     "Content-Type": "application/json; charset=utf-8",
        #     "Authorization": f"Bearer {auth_header}",
        # }

        # ## LOGGING
        # logging_obj.pre_call(
        #     input=input,
        #     api_key="",
        #     additional_args={
        #         "complete_input_dict": request_data,
        #         "api_base": url,
        #         "headers": headers,
        #     },
        # )

        # if aembedding is True:
        #     pass

        # response = sync_handler.post(
        #     url=url,
        #     headers=headers,
        #     data=json.dumps(request_data),
        # )

        # if response.status_code != 200:
        #     raise Exception(f"Error: {response.status_code} {response.text}")

        # _json_response = response.json()
        # if "predictions" not in _json_response:
        #     raise litellm.InternalServerError(
        #         message=f"embedding response does not contain 'predictions', got {_json_response}",
        #         llm_provider="vertex_ai",
        #         model=model,
        #     )
        # _predictions = _json_response["predictions"]

        # model_response.data = _predictions
        # model_response.model = model

        # return model_response
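
Note that text_embeddings returns model_response immediately; everything below the early return, including the whole request flow, is still commented out. Based on that commented sketch and the public Google AI Studio embedContent API, the eventual request would look roughly like this (standalone sketch; the model name and key are placeholders, and the body shape is an assumption drawn from the commented code plus the public API):

import httpx

api_key = "YOUR_GEMINI_API_KEY"  # placeholder
url = (
    "https://generativelanguage.googleapis.com/v1beta/"
    f"models/text-embedding-004:embedContent?key={api_key}"
)
# embedContent expects a content object with text parts.
request_data = {"content": {"parts": [{"text": "good morning from litellm"}]}}

response = httpx.post(
    url,
    headers={"Content-Type": "application/json; charset=utf-8"},
    json=request_data,
)
response.raise_for_status()
# The response carries {"embedding": {"values": [...]}}.
print(response.json()["embedding"]["values"][:5])
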
@@ -0,0 +1,5 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /embedContent format.

Why separate file? Make it easy to see how transformation works
"""
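
The new module holds only this docstring so far. As an illustration of the transformation it is meant to contain (the helper name and shapes below are assumptions, not the module's actual contents): OpenAI's /v1/embeddings input is a string or list of strings, while embedContent wants a content object of text parts.

from typing import List, Union


def openai_input_to_embed_content(input: Union[str, List[str]]) -> List[dict]:
    """Hypothetical helper: map OpenAI-style `input` to one embedContent body per string."""
    texts = [input] if isinstance(input, str) else input
    return [{"content": {"parts": [{"text": text}]}} for text in texts]


# openai_input_to_embed_content("hello")
# -> [{"content": {"parts": [{"text": "hello"}]}}]
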
@@ -1813,6 +1813,7 @@ class VertexLLM(BaseLLM):
             custom_llm_provider=custom_llm_provider,
             api_base=api_base,
             should_use_v1beta1_features=False,
+            mode="embedding",
         )
 
         if client is None:
@@ -1828,11 +1829,6 @@ class VertexLLM(BaseLLM):
         else:
             sync_handler = client  # type: ignore
 
-        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
-
-        auth_header, _ = self._ensure_access_token(
-            credentials=vertex_credentials, project_id=vertex_project
-        )
         optional_params = optional_params or {}
 
         request_data = VertexMultimodalEmbeddingRequest()
@@ -1850,30 +1846,22 @@ class VertexLLM(BaseLLM):
 
         request_data["instances"] = [vertex_request_instance]
 
-        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
-        logging_obj.pre_call(
-            input=[],
-            api_key=None,
-            additional_args={
-                "complete_input_dict": optional_params,
-                "request_str": request_str,
-            },
-        )
-
-        logging_obj.pre_call(
-            input=[],
-            api_key=None,
-            additional_args={
-                "complete_input_dict": optional_params,
-                "request_str": request_str,
-            },
-        )
-
         headers = {
             "Content-Type": "application/json; charset=utf-8",
             "Authorization": f"Bearer {auth_header}",
         }
 
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key="",
+            additional_args={
+                "complete_input_dict": request_data,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
         if aembedding is True:
             return self.async_multimodal_embedding(
                 model=model,
@@ -205,7 +205,7 @@ def get_vertex_client(
     vertex_credentials: Optional[str],
 ) -> Tuple[Any, Optional[str]]:
     args = locals()
-    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+    from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
         VertexLLM,
     )
 
@@ -270,7 +270,7 @@ def completion(
         from anthropic import AnthropicVertex
 
         from litellm.llms.anthropic import AnthropicChatCompletion
-        from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+        from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
             VertexLLM,
         )
     except:
@@ -3134,6 +3134,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
         or custom_llm_provider == "fireworks_ai"
         or custom_llm_provider == "ollama"
         or custom_llm_provider == "vertex_ai"
+        or custom_llm_provider == "gemini"
         or custom_llm_provider == "databricks"
         or custom_llm_provider == "watsonx"
         or custom_llm_provider == "cohere"
@@ -3528,6 +3529,26 @@ def embedding(
             client=client,
             aembedding=aembedding,
         )
+    elif custom_llm_provider == "gemini":
+
+        gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
+
+        response = vertex_chat_completion.multimodal_embedding(  # type: ignore
+            model=model,
+            input=input,
+            encoding=encoding,
+            logging_obj=logging,
+            optional_params=optional_params,
+            model_response=EmbeddingResponse(),
+            vertex_project=None,
+            vertex_location=None,
+            vertex_credentials=None,
+            aembedding=aembedding,
+            print_verbose=print_verbose,
+            custom_llm_provider="gemini",
+            api_key=gemini_api_key,
+        )
+
     elif custom_llm_provider == "vertex_ai":
         vertex_ai_project = (
             optional_params.pop("vertex_project", None)
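
With this routing in place, Google AI Studio embeddings are meant to be reachable through the normal litellm entry points once the handler is fully implemented. A usage sketch (model name from the test below; the key is read from GEMINI_API_KEY):

import os

import litellm

os.environ["GEMINI_API_KEY"] = "YOUR_GEMINI_API_KEY"  # placeholder

response = litellm.embedding(
    model="gemini/text-embedding-004",
    input=["good morning from litellm"],
)
print(response.data)
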
@@ -686,6 +686,22 @@ async def test_triton_embeddings():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.asyncio
+async def test_gemini_embeddings():
+    try:
+        litellm.set_verbose = True
+        response = await litellm.aembedding(
+            model="gemini/text-embedding-004",
+            input=["good morning from litellm"],
+        )
+        print(f"response: {response}")
+
+        # stubbed endpoint is setup to return this
+        assert response.data[0]["embedding"] == [0.1, 0.2]
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_databricks_embeddings(sync_mode):