VertexAI non-jsonl file storage support (#9781)

* test: add initial e2e test

* fix(vertex_ai/files): initial commit adding sync file create support

* refactor: initial commit of vertex ai non-jsonl files reaching gcp endpoint

* fix(vertex_ai/files/transformation.py): initial working commit of non-jsonl file call reaching backend endpoint

* fix(vertex_ai/files/transformation.py): working e2e non-jsonl file upload

* test: working e2e jsonl call

* test: unit testing for jsonl file creation

* fix(vertex_ai/transformation.py): reset file pointer after read

allow multiple reads on same file object (see the sketch after this list)

* fix: fix linting errors

* fix: fix ruff linting errors

* fix: fix import

* fix: fix linting error

* fix: fix linting error

* fix(vertex_ai/files/transformation.py): fix linting error

* test: update test

* test: update tests

* fix: fix linting errors

* fix: fix test

* fix: fix linting error
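
A minimal sketch (not the PR's actual transformation code) of the file-pointer issue the "reset file pointer after read" commit fixes: reading a file object advances its internal offset to the end, so a second read returns empty bytes unless the pointer is rewound with seek(0) first. The in-memory file and its JSON payload below are made-up stand-ins for the uploaded file object:

import io

# Stand-in for an uploaded file object; payload contents are illustrative only.
file_obj = io.BytesIO(b'{"custom_id": "1", "body": {"input": "hello"}}')

first = file_obj.read()    # consumes the stream; offset is now at EOF
file_obj.seek(0)           # reset the pointer after read
second = file_obj.read()   # a second read now sees the same full contents

assert first == second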
Krish Dholakia, 2025-04-09 14:01:48 -07:00 (committed by GitHub)
parent 93532e00db
commit 6ba3c4a4f8
64 changed files with 780 additions and 185 deletions


@@ -1,15 +1,6 @@
 import json
 import os
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Union,
-    get_args,
-)
+from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args
 
 import httpx
@@ -35,8 +26,9 @@ hf_tasks_embeddings = Literal[  # pipeline tags + hf tei endpoints - https://hug
 ]
 
 
-def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
+def get_hf_task_embedding_for_model(
+    model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
     if task_type is not None:
         if task_type in get_args(hf_tasks_embeddings):
             return task_type
@@ -57,7 +49,9 @@ def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_ba
     return pipeline_tag
 
 
-async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
+async def async_get_hf_task_embedding_for_model(
+    model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
     if task_type is not None:
         if task_type in get_args(hf_tasks_embeddings):
             return task_type
@@ -116,7 +110,9 @@ class HuggingFaceEmbedding(BaseLLM):
         input: List,
         optional_params: dict,
     ) -> dict:
-        hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+        hf_task = await async_get_hf_task_embedding_for_model(
+            model=model, task_type=task_type, api_base=HF_HUB_URL
+        )
 
         data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
@@ -173,7 +169,9 @@ class HuggingFaceEmbedding(BaseLLM):
         task_type = optional_params.pop("input_type", None)
 
         if call_type == "sync":
-            hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+            hf_task = get_hf_task_embedding_for_model(
+                model=model, task_type=task_type, api_base=HF_HUB_URL
+            )
         elif call_type == "async":
             return self._async_transform_input(
                 model=model, task_type=task_type, embed_url=embed_url, input=input
@@ -325,6 +323,7 @@ class HuggingFaceEmbedding(BaseLLM):
         input: list,
         model_response: EmbeddingResponse,
         optional_params: dict,
+        litellm_params: dict,
         logging_obj: LiteLLMLoggingObj,
         encoding: Callable,
         api_key: Optional[str] = None,
@@ -341,9 +340,12 @@ class HuggingFaceEmbedding(BaseLLM):
             model=model,
             optional_params=optional_params,
             messages=[],
+            litellm_params=litellm_params,
         )
         task_type = optional_params.pop("input_type", None)
-        task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+        task = get_hf_task_embedding_for_model(
+            model=model, task_type=task_type, api_base=HF_HUB_URL
+        )
         # print_verbose(f"{model}, {task}")
         embed_url = ""
         if "https" in model:
@@ -355,7 +357,9 @@ class HuggingFaceEmbedding(BaseLLM):
         elif "HUGGINGFACE_API_BASE" in os.environ:
             embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
         else:
-            embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
+            embed_url = (
+                f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
+            )
 
         ## ROUTING ##
         if aembedding is True:
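
The last hunk's change is purely line-wrapping, but together with the preceding context it shows the endpoint-resolution order this handler uses: an explicit https URL in the model string wins, then the HUGGINGFACE_API_BASE environment variable, then the hf-inference router default. A simplified standalone restatement of that precedence; the first branch's body and any other branches (e.g. an explicit api_base argument) are not visible in the hunks, so they are assumptions here:

import os

def resolve_embed_url(model: str, task: str) -> str:
    # Precedence sketched from the hunks above, not the exact library code.
    if "https" in model:
        return model  # assumed: the model string already carries the full endpoint
    if "HUGGINGFACE_API_BASE" in os.environ:
        return os.getenv("HUGGINGFACE_API_BASE", "")
    return f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"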