[Vertex Multimodal embeddings] Fixes to work with Langchain OpenAI Embedding (#5949)
* fix parallel request limiter - use one cache update call
* ci/cd run again
* run ci/cd again
* use docker username password
* fix config.yml
* fix config
* fix config
* fix config.yml
* ci/cd run again
* use correct typing for batch set cache
* fix async_set_cache_pipeline
* fix only check user id tpm / rpm limits when limits set
* fix test_openai_azure_embedding_with_oidc_and_cf
* add InstanceImage type
* fix vertex image transform
* add langchain vertex test request
* add new vertex test
* update multimodal embedding tests
* add test_vertexai_multimodal_embedding_base64image_in_input
* simplify langchain mm embedding usage
* add langchain example for multimodal embeddings on vertex
* fix linting error

Parent: bd17424c4b
Commit: fd87ae69b8

6 changed files with 318 additions and 54 deletions
@@ -1684,17 +1684,21 @@ Usage
 <Tabs>
 <TabItem value="sdk" label="SDK">

+Using GCS Images
+
 ```python
 response = await litellm.aembedding(
     model="vertex_ai/multimodalembedding@001",
-    input=[
-        {
-            "image": {
-                "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
-            },
-            "text": "this is a unicorn",
-        },
-    ],
+    input="gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",  # will be sent as a gcs image
 )
 ```
+
+Using base 64 encoded images
+
+```python
+response = await litellm.aembedding(
+    model="vertex_ai/multimodalembedding@001",
+    input="data:image/jpeg;base64,...",  # will be sent as a base64 encoded image
+)
+```
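Note: because each input string is classified server-side (see `_process_input_element` later in this diff), text, GCS URIs, and base64 data URIs can also be mixed in one list. A sketch of that usage; the mixed `input` list is extrapolated from the list-handling code in this commit, not taken from the docs:

```python
import asyncio

import litellm


async def main():
    # Each list element is classified independently:
    # "gs://..." -> image (or video, for .mp4 paths),
    # base64 / data URI -> image bytes, anything else -> plain text.
    response = await litellm.aembedding(
        model="vertex_ai/multimodalembedding@001",
        input=[
            "this is a unicorn",  # plain text
            "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",  # GCS image
        ],
    )
    print(response.data)


asyncio.run(main())
```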
@@ -1721,9 +1725,15 @@ litellm_settings:
 $ litellm --config /path/to/config.yaml
 ```

-3. Make Request use OpenAI Python SDK
+3. Make a request using the OpenAI Python SDK or Langchain Python SDK

+<Tabs>
+
+<TabItem value="OpenAI SDK" label="OpenAI SDK">
+
+Requests with GCS Image / Video URI

 ```python
 import openai
@@ -1732,23 +1742,13 @@ client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

 # request sent to model set on litellm proxy, `litellm --model`
 response = client.embeddings.create(
     model="multimodalembedding@001",
-    input = None,
-    extra_body = {
-        "instances": [
-            {
-                "image": {
-                    "bytesBase64Encoded": "base64"
-                },
-                "text": "this is a unicorn",
-            },
-        ],
-    }
+    input = "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
 )

 print(response)
 ```

+Requests with base64 encoded images
+
 ```python
 import openai
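Note: since `input` is now a standard OpenAI embeddings parameter (no `extra_body` needed), the same request can be sent over raw HTTP to the proxy's OpenAI-compatible route; a sketch, not taken from the docs:

```python
import requests

# Same request as the OpenAI SDK snippet above, as a plain POST to the
# proxy's OpenAI-compatible embeddings endpoint.
resp = requests.post(
    "http://0.0.0.0:4000/embeddings",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "multimodalembedding@001",
        "input": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
    },
)
print(resp.json())
```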
@@ -1758,23 +1758,63 @@ client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

 # request sent to model set on litellm proxy, `litellm --model`
 response = client.embeddings.create(
     model="multimodalembedding@001",
-    input = None,
-    extra_body = {
-        "instances": [
-            {
-                "image": {
-                    "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
-                },
-                "text": "this is a unicorn",
-            },
-        ],
-    }
+    input = "data:image/jpeg;base64,...",
 )

 print(response)
 ```

 </TabItem>

+<TabItem value="langchain" label="Langchain">
+
+Requests with GCS Image / Video URI
+
+```python
+from langchain_openai import OpenAIEmbeddings
+
+embeddings = OpenAIEmbeddings(
+    model="multimodalembedding@001",
+    base_url="http://0.0.0.0:4000",
+    api_key="sk-1234",  # type: ignore
+)
+
+query_result = embeddings.embed_query(
+    "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
+)
+print(query_result)
+```
+
+Requests with base64 encoded images
+
+```python
+from langchain_openai import OpenAIEmbeddings
+
+embeddings = OpenAIEmbeddings(
+    model="multimodalembedding@001",
+    base_url="http://0.0.0.0:4000",
+    api_key="sk-1234",  # type: ignore
+)
+
+query_result = embeddings.embed_query(
+    "data:image/jpeg;base64,..."
+)
+print(query_result)
+```
+
+</TabItem>
+
+</Tabs>
+
+</TabItem>

 <TabItem value="proxy-vtx" label="LiteLLM PROXY (Vertex SDK)">

 1. Add model to config.yaml
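Note: batch embedding through Langchain's `embed_documents` should route through the same list-processing path; a sketch under that assumption. Depending on the setup, `check_embedding_ctx_length=False` may be needed so Langchain sends raw strings rather than pre-tokenized input:

```python
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(
    model="multimodalembedding@001",
    base_url="http://0.0.0.0:4000",
    api_key="sk-1234",  # type: ignore
    check_embedding_ctx_length=False,  # skip tiktoken pre-tokenization, send raw strings
)

# One embedding per element; gs:// URIs and plain text can be mixed.
doc_results = embeddings.embed_documents(
    [
        "this is a unicorn",
        "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
    ]
)
print(len(doc_results))
```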
@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 import httpx

 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexAIError,
@@ -11,9 +12,14 @@ from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
 )
 from litellm.types.llms.vertex_ai import (
     Instance,
+    InstanceImage,
     InstanceVideo,
+    MultimodalPrediction,
+    MultimodalPredictions,
     VertexMultimodalEmbeddingRequest,
 )
+from litellm.types.utils import Embedding
+from litellm.utils import is_base64_encoded


 class VertexMultimodalEmbedding(VertexLLM):
@@ -32,9 +38,9 @@ class VertexMultimodalEmbedding(VertexLLM):
         model_response: litellm.EmbeddingResponse,
         custom_llm_provider: Literal["gemini", "vertex_ai"],
         optional_params: dict,
+        logging_obj: LiteLLMLoggingObj,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
-        logging_obj=None,
         encoding=None,
         vertex_project=None,
         vertex_location=None,
@@ -94,7 +100,7 @@ class VertexMultimodalEmbedding(VertexLLM):
         vertex_request_instance = Instance(**optional_params)

         if isinstance(input, str):
-            vertex_request_instance["text"] = input
+            vertex_request_instance = self._process_input_element(input)

         request_data["instances"] = [vertex_request_instance]

@@ -142,8 +148,10 @@ class VertexMultimodalEmbedding(VertexLLM):
             model=model,
         )
         _predictions = _json_response["predictions"]
-        model_response.data = _predictions
+        vertex_predictions = MultimodalPredictions(predictions=_predictions)
+        model_response.data = self.transform_embedding_response_to_openai(
+            predictions=vertex_predictions
+        )
         model_response.model = model

         return model_response
@@ -186,11 +194,36 @@ class VertexMultimodalEmbedding(VertexLLM):
         )
         _predictions = _json_response["predictions"]

-        model_response.data = _predictions
+        vertex_predictions = MultimodalPredictions(predictions=_predictions)
+        model_response.data = self.transform_embedding_response_to_openai(
+            predictions=vertex_predictions
+        )
         model_response.model = model

         return model_response

+    def _process_input_element(self, input_element: str) -> Instance:
+        """
+        Process an input element for a multimodal embedding request. Checks whether the input is a GCS URI, a base64 encoded image, or plain text.
+
+        Args:
+            input_element (str): The input element to process.
+
+        Returns:
+            Instance: An Instance dict representing the processed input element.
+        """
+        if len(input_element) == 0:
+            return Instance(text=input_element)
+        elif "gs://" in input_element:
+            if "mp4" in input_element:
+                return Instance(video=InstanceVideo(gcsUri=input_element))
+            else:
+                return Instance(image=InstanceImage(gcsUri=input_element))
+        elif is_base64_encoded(s=input_element):
+            return Instance(image=InstanceImage(bytesBase64Encoded=input_element))
+        else:
+            return Instance(text=input_element)
+
     def process_openai_embedding_input(
         self, _input: Union[list, str]
     ) -> List[Instance]:
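Note: the classification above is substring-based ("gs://" anywhere in the string selects GCS, and "mp4" anywhere in a GCS path selects video). An illustrative check of the expected return values; `handler` stands in for a `VertexMultimodalEmbedding` instance (the handler's file path isn't shown in this diff, so the import is omitted), and `Instance` is a TypedDict, so results compare equal to plain dicts:

```python
# Illustrative only; `handler` is assumed to be a VertexMultimodalEmbedding instance.
assert handler._process_input_element("this is a unicorn") == {"text": "this is a unicorn"}
assert handler._process_input_element("gs://bucket/img.png") == {
    "image": {"gcsUri": "gs://bucket/img.png"}
}
assert handler._process_input_element("gs://bucket/clip.mp4") == {
    "video": {"gcsUri": "gs://bucket/clip.mp4"}
}
# base64 payloads (e.g. "data:image/jpeg;base64,...") become
# {"image": {"bytesBase64Encoded": <payload>}} via is_base64_encoded().
```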
@@ -211,14 +244,45 @@ class VertexMultimodalEmbedding(VertexLLM):
             _input_list = _input

         processed_instances = []
-        for element in _input:
-            if not isinstance(element, dict):
-                # assuming that input is a list of strings
-                # example: input = ["hello from litellm"]
-                instance = Instance(text=element)
-            else:
-                # assume this is a
+        for element in _input_list:
+            if isinstance(element, str):
+                instance = Instance(**self._process_input_element(element))
+            elif isinstance(element, dict):
                 instance = Instance(**element)
+            else:
+                raise ValueError(f"Unsupported input type: {type(element)}")
             processed_instances.append(instance)

         return processed_instances

+    def transform_embedding_response_to_openai(
+        self, predictions: MultimodalPredictions
+    ) -> List[Embedding]:
+        openai_embeddings: List[Embedding] = []
+        if "predictions" in predictions:
+            for idx, _prediction in enumerate(predictions["predictions"]):
+                if _prediction:
+                    if "textEmbedding" in _prediction:
+                        openai_embedding_object = Embedding(
+                            embedding=_prediction["textEmbedding"],
+                            index=idx,
+                            object="embedding",
+                        )
+                        openai_embeddings.append(openai_embedding_object)
+                    elif "imageEmbedding" in _prediction:
+                        openai_embedding_object = Embedding(
+                            embedding=_prediction["imageEmbedding"],
+                            index=idx,
+                            object="embedding",
+                        )
+                        openai_embeddings.append(openai_embedding_object)
+                    elif "videoEmbeddings" in _prediction:
+                        for video_embedding in _prediction["videoEmbeddings"]:
+                            openai_embedding_object = Embedding(
+                                embedding=video_embedding["embedding"],
+                                index=idx,
+                                object="embedding",
+                            )
+                            openai_embeddings.append(openai_embedding_object)
+        return openai_embeddings
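Note: `transform_embedding_response_to_openai` maps each prediction to one OpenAI-style `Embedding`, except video predictions, which fan out into one embedding per segment that all share the source prediction's index. An illustrative input/output pair (values invented for the example):

```python
vertex_predictions = {
    "predictions": [
        {"imageEmbedding": [0.1, 0.2, 0.3]},
        {"textEmbedding": [0.4, 0.5, 0.6]},
        {
            "videoEmbeddings": [
                {"startOffsetSec": 0, "endOffsetSec": 7, "embedding": [0.7, 0.8]},
                {"startOffsetSec": 7, "endOffsetSec": 14, "embedding": [0.9, 1.0]},
            ]
        },
    ]
}
# Expected OpenAI-style output:
# [
#     {"embedding": [0.1, 0.2, 0.3], "index": 0, "object": "embedding"},
#     {"embedding": [0.4, 0.5, 0.6], "index": 1, "object": "embedding"},
#     {"embedding": [0.7, 0.8], "index": 2, "object": "embedding"},
#     {"embedding": [0.9, 1.0], "index": 2, "object": "embedding"},  # same index
# ]
```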
@@ -1,4 +1,14 @@
 model_list:
+  - model_name: multimodalembedding@001
+    litellm_params:
+      model: vertex_ai/multimodalembedding@001
+      vertex_project: "adroit-crow-413218"
+      vertex_location: "us-central1"
+      vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json"
+  - model_name: text-embedding-ada-002
+    litellm_params:
+      model: openai/text-embedding-ada-002 # the `openai/` prefix routes the request to the OpenAI API
+      api_key: os.environ/OPENAI_API_KEY
   - model_name: db-openai-endpoint
     litellm_params:
       model: openai/gpt-3.5-turbo
@@ -23,11 +33,10 @@ general_settings:
   service_account_settings:
     enforced_params: ["user"]


 litellm_settings:
-  cache: true
-  # callbacks: ["otel"]
+  drop_params: True
+  callbacks: ["otel"]
+  success_callback: ["langfuse"]
+  failure_callback: ["langfuse"]

-general_settings:
-  service_account_settings:
-    enforced_params: ["user"]
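Note: `drop_params` also exists as a module-level switch in the Python SDK; a minimal equivalent of the proxy setting above (not part of this commit):

```python
import litellm

# Equivalent of the proxy's `drop_params: True` setting: OpenAI params the
# target provider doesn't support are dropped instead of raising an error.
litellm.drop_params = True
```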
litellm/proxy/tests/test_langchain_embedding.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+from langchain_openai import OpenAIEmbeddings
+
+embeddings_models = "multimodalembedding@001"
+
+embeddings = OpenAIEmbeddings(
+    model="multimodalembedding@001",
+    base_url="http://0.0.0.0:4000",
+    api_key="sk-1234",  # type: ignore
+)
+
+
+query_result = embeddings.embed_query(
+    "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
+)
+# print(len(query_result))
+# print(query_result[:5])
+print(query_result)
@@ -1931,8 +1931,6 @@ async def test_vertexai_multimodal_embedding():
         assert response.model == "multimodalembedding@001"
         assert len(response.data) == 1
         response_data = response.data[0]
-        assert "imageEmbedding" in response_data
-        assert "textEmbedding" in response_data

         # Optional: Print for debugging
         print("Arguments passed to Vertex AI:", args_to_vertexai)
@@ -1987,7 +1985,121 @@ async def test_vertexai_multimodal_embedding_text_input():
         assert response.model == "multimodalembedding@001"
         assert len(response.data) == 1
         response_data = response.data[0]
-        assert "textEmbedding" in response_data
+        assert response_data["embedding"] == [0.4, 0.5, 0.6]
+
+        # Optional: Print for debugging
+        print("Arguments passed to Vertex AI:", args_to_vertexai)
+        print("Response:", response)
+
+
+@pytest.mark.asyncio
+async def test_vertexai_multimodal_embedding_image_in_input():
+    load_vertex_ai_credentials()
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "predictions": [
+                {
+                    "imageEmbedding": [0.1, 0.2, 0.3],  # Simplified example
+                }
+            ]
+        }
+
+    mock_response.json = return_val
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "instances": [
+            {
+                "image": {
+                    "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
+                },
+            }
+        ]
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        # Act: Call the litellm.aembedding function
+        response = await litellm.aembedding(
+            model="vertex_ai/multimodalembedding@001",
+            input=["gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"],
+        )
+
+        # Assert
+        mock_post.assert_called_once()
+        _, kwargs = mock_post.call_args
+        args_to_vertexai = kwargs["json"]
+
+        print("args to vertex ai call:", args_to_vertexai)
+
+        assert args_to_vertexai == expected_payload
+        assert response.model == "multimodalembedding@001"
+        assert len(response.data) == 1
+        response_data = response.data[0]
+
+        assert response_data["embedding"] == [0.1, 0.2, 0.3]
+
+        # Optional: Print for debugging
+        print("Arguments passed to Vertex AI:", args_to_vertexai)
+        print("Response:", response)
+
+
+@pytest.mark.asyncio
+async def test_vertexai_multimodal_embedding_base64image_in_input():
+    load_vertex_ai_credentials()
+    mock_response = AsyncMock()
+
+    image_path = "../proxy/cached_logo.jpg"
+    # Getting the base64 string
+    base64_image = encode_image(image_path)
+
+    def return_val():
+        return {
+            "predictions": [
+                {
+                    "imageEmbedding": [0.1, 0.2, 0.3],  # Simplified example
+                }
+            ]
+        }
+
+    mock_response.json = return_val
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "instances": [
+            {
+                "image": {"bytesBase64Encoded": base64_image},
+            }
+        ]
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        # Act: Call the litellm.aembedding function
+        response = await litellm.aembedding(
+            model="vertex_ai/multimodalembedding@001",
+            input=[base64_image],
+        )
+
+        # Assert
+        mock_post.assert_called_once()
+        _, kwargs = mock_post.call_args
+        args_to_vertexai = kwargs["json"]
+
+        print("args to vertex ai call:", args_to_vertexai)
+
+        assert args_to_vertexai == expected_payload
+        assert response.model == "multimodalembedding@001"
+        assert len(response.data) == 1
+        response_data = response.data[0]
+
+        assert response_data["embedding"] == [0.1, 0.2, 0.3]
+
         # Optional: Print for debugging
         print("Arguments passed to Vertex AI:", args_to_vertexai)
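Note: `encode_image` is referenced but not defined in this hunk; a plausible sketch of such a helper (an assumption, the real helper lives elsewhere in the test file):

```python
import base64


def encode_image(image_path: str) -> str:
    # Read the file and return its bytes as a base64-encoded string.
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
```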
@@ -339,9 +339,15 @@ class InstanceVideo(TypedDict, total=False):
     videoSegmentConfig: Tuple[float, float, float]


+class InstanceImage(TypedDict, total=False):
+    gcsUri: Optional[str]
+    bytesBase64Encoded: Optional[str]
+    mimeType: Optional[str]
+
+
 class Instance(TypedDict, total=False):
     text: str
-    image: Dict[str, str]
+    image: InstanceImage
     video: InstanceVideo
@@ -349,6 +355,22 @@ class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
     instances: List[Instance]


+class VideoEmbedding(TypedDict):
+    startOffsetSec: int
+    endOffsetSec: int
+    embedding: List[float]
+
+
+class MultimodalPrediction(TypedDict, total=False):
+    textEmbedding: List[float]
+    imageEmbedding: List[float]
+    videoEmbeddings: List[VideoEmbedding]
+
+
+class MultimodalPredictions(TypedDict, total=False):
+    predictions: List[MultimodalPrediction]
+
+
 class VertexAICachedContentResponseObject(TypedDict):
     name: str
     model: str
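Note: with these TypedDicts, a full multimodal request can be built in a type-checked way; a minimal sketch using placeholder gs:// paths:

```python
from litellm.types.llms.vertex_ai import (
    Instance,
    InstanceImage,
    InstanceVideo,
    VertexMultimodalEmbeddingRequest,
)

# One request mixing all three instance kinds; the bucket paths are
# placeholders, not real sample data.
request = VertexMultimodalEmbeddingRequest(
    instances=[
        Instance(text="this is a unicorn"),
        Instance(image=InstanceImage(gcsUri="gs://bucket/img.png")),
        Instance(video=InstanceVideo(gcsUri="gs://bucket/clip.mp4")),
    ]
)
print(request)
```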