fix map input_type to task_type for vertex ai

Ishaan Jaff 2024-08-30 12:09:07 -07:00
parent 6a5992ed2d
commit ea12519b98
7 changed files with 311 additions and 271 deletions
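
This commit teaches LiteLLM's Vertex AI text-embedding handler to translate the provider-agnostic input_type parameter into Vertex AI's task_type. Previously only the OpenAI-style dimensions parameter was mapped (to output_dimensionality). A minimal sketch of the resulting call, assuming Vertex AI credentials and project are already configured:

import litellm

# input_type is now accepted for vertex_ai embedding models and is sent to the
# Vertex text-embeddings API as task_type (see map_openai_params below)
response = litellm.embedding(
    model="vertex_ai/text-embedding-004",
    input=["hi"],
    input_type="RETRIEVAL_QUERY",
)
print(response.usage.prompt_tokens)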

litellm/__init__.py

@@ -861,7 +861,7 @@ from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gem
GoogleAIStudioGeminiConfig,
VertexAIConfig,
)
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
VertexAITextEmbeddingConfig,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
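
(The public import path is unchanged: litellm still exposes VertexAITextEmbeddingConfig; only the module providing it moves from vertex_ai_non_gemini to the new vertex_embeddings/embedding_handler.)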

litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py

@@ -1096,269 +1096,3 @@ async def async_streaming(
)
return streamwrapper

class VertexAITextEmbeddingConfig(BaseModel):
"""
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
Args:
auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
"""
auto_truncate: Optional[bool] = None
task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
]
] = None
title: Optional[str] = None

    def __init__(
self,
auto_truncate: Optional[bool] = None,
task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
]
] = None,
title: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

    @classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}

    def get_supported_openai_params(self):
return [
"dimensions",
]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "dimensions":
optional_params["output_dimensionality"] = value
return optional_params

    def get_mapped_special_auth_params(self) -> dict:
"""
Common auth params across bedrock/vertex_ai/azure/watsonx
"""
return {"project": "vertex_project", "region_name": "vertex_location"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
mapped_params = self.get_mapped_special_auth_params()
for param, value in non_default_params.items():
if param in mapped_params:
optional_params[mapped_params[param]] = value
return optional_params

def embedding(
model: str,
input: Union[list, str],
print_verbose,
model_response: litellm.EmbeddingResponse,
optional_params: dict,
api_key: Optional[str] = None,
logging_obj=None,
encoding=None,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
aembedding=False,
):
# logic for parsing in - calling - parsing out model embedding calls
try:
import vertexai
    except ImportError:
raise VertexAIError(
status_code=400,
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
)
import google.auth # type: ignore
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
try:
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
)
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account
json_obj = json.loads(vertex_credentials)
creds = google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
creds, _ = google.auth.default(quota_project_id=vertex_project)
print_verbose(
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
)
vertexai.init(
project=vertex_project, location=vertex_location, credentials=creds
)
except Exception as e:
raise VertexAIError(status_code=401, message=str(e))
if isinstance(input, str):
input = [input]
if optional_params is not None and isinstance(optional_params, dict):
if optional_params.get("task_type") or optional_params.get("title"):
# if user passed task_type or title, cast to TextEmbeddingInput
_task_type = optional_params.pop("task_type", None)
_title = optional_params.pop("title", None)
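            # popped above so task_type/title are not passed a second time when
            # optional_params is expanded into the get_embeddings() kwargs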
input = [
TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
for x in input
]
try:
llm_model = TextEmbeddingModel.from_pretrained(model)
except Exception as e:
raise VertexAIError(status_code=422, message=str(e))
    if aembedding is True:
return async_embedding(
model=model,
client=llm_model,
input=input,
logging_obj=logging_obj,
model_response=model_response,
optional_params=optional_params,
encoding=encoding,
)
_input_dict = {"texts": input, **optional_params}
request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
## LOGGING PRE-CALL
logging_obj.pre_call(
input=input,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
try:
embeddings = llm_model.get_embeddings(**_input_dict)
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
## LOGGING POST-CALL
logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
## Populate OpenAI compliant dictionary
embedding_response = []
input_tokens: int = 0
for idx, embedding in enumerate(embeddings):
embedding_response.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding.values,
}
)
input_tokens += embedding.statistics.token_count # type: ignore
model_response.object = "list"
model_response.data = embedding_response
model_response.model = model
usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
)
setattr(model_response, "usage", usage)
return model_response

async def async_embedding(
model: str,
input: Union[list, str],
model_response: litellm.EmbeddingResponse,
logging_obj=None,
optional_params=None,
encoding=None,
client=None,
):
"""
Async embedding implementation
"""
_input_dict = {"texts": input, **optional_params}
request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
## LOGGING PRE-CALL
logging_obj.pre_call(
input=input,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
try:
embeddings = await client.get_embeddings_async(**_input_dict)
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
## LOGGING POST-CALL
logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
## Populate OpenAI compliant dictionary
embedding_response = []
input_tokens: int = 0
for idx, embedding in enumerate(embeddings):
embedding_response.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding.values,
}
)
input_tokens += embedding.statistics.token_count
model_response.object = "list"
model_response.data = embedding_response
model_response.model = model
usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
)
setattr(model_response, "usage", usage)
return model_response
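
The roughly 270 lines removed above are not dropped: the same implementation moves into the new vertex_embeddings/embedding_handler.py below, with one behavioral change: get_supported_openai_params() now also returns "input_type", and map_openai_params() translates it to Vertex AI's task_type.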

litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py

@@ -0,0 +1,280 @@
import json
import os
import types
from typing import Literal, Optional, Union
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
VertexAIError,
)
from litellm.types.llms.vertex_ai import *
from litellm.utils import Usage

class VertexAITextEmbeddingConfig(BaseModel):
"""
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
Args:
auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
"""
auto_truncate: Optional[bool] = None
task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
]
] = None
title: Optional[str] = None

    def __init__(
self,
auto_truncate: Optional[bool] = None,
task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
]
] = None,
title: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

    @classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}

    def get_supported_openai_params(self):
return ["dimensions", "input_type"]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "dimensions":
optional_params["output_dimensionality"] = value
if param == "input_type":
optional_params["task_type"] = value
return optional_params

    def get_mapped_special_auth_params(self) -> dict:
"""
Common auth params across bedrock/vertex_ai/azure/watsonx
"""
return {"project": "vertex_project", "region_name": "vertex_location"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
mapped_params = self.get_mapped_special_auth_params()
for param, value in non_default_params.items():
if param in mapped_params:
optional_params[mapped_params[param]] = value
return optional_params

def embedding(
model: str,
input: Union[list, str],
print_verbose,
model_response: litellm.EmbeddingResponse,
optional_params: dict,
api_key: Optional[str] = None,
logging_obj=None,
encoding=None,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
aembedding=False,
):
# logic for parsing in - calling - parsing out model embedding calls
try:
import vertexai
    except ImportError:
raise VertexAIError(
status_code=400,
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
)
import google.auth # type: ignore
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
try:
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
)
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account
json_obj = json.loads(vertex_credentials)
creds = google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
creds, _ = google.auth.default(quota_project_id=vertex_project)
print_verbose(
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
)
vertexai.init(
project=vertex_project, location=vertex_location, credentials=creds
)
except Exception as e:
raise VertexAIError(status_code=401, message=str(e))
if isinstance(input, str):
input = [input]
if optional_params is not None and isinstance(optional_params, dict):
if optional_params.get("task_type") or optional_params.get("title"):
# if user passed task_type or title, cast to TextEmbeddingInput
_task_type = optional_params.pop("task_type", None)
_title = optional_params.pop("title", None)
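            # popped above so task_type/title are not passed a second time when
            # optional_params is expanded into the get_embeddings() kwargs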
input = [
TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
for x in input
]
try:
llm_model = TextEmbeddingModel.from_pretrained(model)
except Exception as e:
raise VertexAIError(status_code=422, message=str(e))
    if aembedding is True:
return async_embedding(
model=model,
client=llm_model,
input=input,
logging_obj=logging_obj,
model_response=model_response,
optional_params=optional_params,
encoding=encoding,
)
_input_dict = {"texts": input, **optional_params}
request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
## LOGGING PRE-CALL
logging_obj.pre_call(
input=input,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
try:
embeddings = llm_model.get_embeddings(**_input_dict)
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
## LOGGING POST-CALL
logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
## Populate OpenAI compliant dictionary
embedding_response = []
input_tokens: int = 0
for idx, embedding in enumerate(embeddings):
embedding_response.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding.values,
}
)
input_tokens += embedding.statistics.token_count # type: ignore
model_response.object = "list"
model_response.data = embedding_response
model_response.model = model
usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
)
setattr(model_response, "usage", usage)
return model_response

async def async_embedding(
model: str,
input: Union[list, str],
model_response: litellm.EmbeddingResponse,
logging_obj=None,
optional_params=None,
encoding=None,
client=None,
):
"""
Async embedding implementation
"""
_input_dict = {"texts": input, **optional_params}
request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
## LOGGING PRE-CALL
logging_obj.pre_call(
input=input,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
try:
embeddings = await client.get_embeddings_async(**_input_dict)
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
## LOGGING POST-CALL
logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
## Populate OpenAI compliant dictionary
embedding_response = []
input_tokens: int = 0
for idx, embedding in enumerate(embeddings):
embedding_response.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding.values,
}
)
input_tokens += embedding.statistics.token_count
model_response.object = "list"
model_response.data = embedding_response
model_response.model = model
usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
)
setattr(model_response, "usage", usage)
return model_response
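
For reference, a minimal illustration of the new mapping (illustrative values, not part of the committed file):

from litellm import VertexAITextEmbeddingConfig

config = VertexAITextEmbeddingConfig()
optional_params = config.map_openai_params(
    non_default_params={"dimensions": 256, "input_type": "RETRIEVAL_QUERY"},
    optional_params={},
)
# -> {"output_dimensionality": 256, "task_type": "RETRIEVAL_QUERY"}
print(optional_params)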

litellm/main.py

@@ -126,18 +126,21 @@ from .llms.vertex_ai_and_google_ai_studio import (
vertex_ai_anthropic,
vertex_ai_non_gemini,
)
-from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler import (
-    GoogleBatchEmbeddings,
-)
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
+from .llms.vertex_ai_and_google_ai_studio.gemini_embeddings.batch_embed_content_handler import (
+    GoogleBatchEmbeddings,
+)
from .llms.vertex_ai_and_google_ai_studio.multimodal_embeddings.embedding_handler import (
VertexMultimodalEmbedding,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
VertexAIPartnerModels,
)
+from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings import (
+    embedding_handler as vertex_ai_embedding_handler,
+)
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
from .types.utils import (
@@ -3606,7 +3609,7 @@ def embedding(
custom_llm_provider="vertex_ai",
)
else:
-        response = vertex_ai_non_gemini.embedding(
+        response = vertex_ai_embedding_handler.embedding(
model=model,
input=input,
encoding=encoding,

litellm/tests/test_embedding.py

@@ -2014,6 +2014,29 @@ def test_vertexai_embedding_embedding_latest():
pytest.fail(f"Error occurred: {e}")
+# @pytest.mark.skip(
+# reason="new test - works locally running into vertex version issues on ci/cd"
+# )
+def test_vertexai_embedding_embedding_latest_input_type():
+    try:
+        load_vertex_ai_credentials()
+        litellm.set_verbose = True
+        response = embedding(
+            model="vertex_ai/text-embedding-004",
+            input=["hi"],
+            input_type="RETRIEVAL_QUERY",
+        )
+        assert len(response.data[0]["embedding"]) > 0
+        assert response.usage.prompt_tokens > 0
+        print(f"response: {response}")
+    except litellm.RateLimitError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_vertexai_aembedding():
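
Assuming the tests above live in litellm/tests/test_embedding.py, the new case can be run with: pytest litellm/tests/test_embedding.py -k input_type. The async path reaches async_embedding above via aembedding=True; a minimal sketch, with the same credential assumptions as the tests:

import asyncio

import litellm

async def main():
    # aembedding routes through the same map_openai_params translation,
    # then into async_embedding via the aembedding=True branch
    response = await litellm.aembedding(
        model="vertex_ai/text-embedding-004",
        input=["hi"],
        input_type="RETRIEVAL_QUERY",
    )
    print(response.usage)

asyncio.run(main())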