mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
VertexAI non-jsonl file storage support (#9781)
* test: add initial e2e test * fix(vertex_ai/files): initial commit adding sync file create support * refactor: initial commit of vertex ai non-jsonl files reaching gcp endpoint * fix(vertex_ai/files/transformation.py): initial working commit of non-jsonl file call reaching backend endpoint * fix(vertex_ai/files/transformation.py): working e2e non-jsonl file upload * test: working e2e jsonl call * test: unit testing for jsonl file creation * fix(vertex_ai/transformation.py): reset file pointer after read allow multiple reads on same file object * fix: fix linting errors * fix: fix ruff linting errors * fix: fix import * fix: fix linting error * fix: fix linting error * fix(vertex_ai/files/transformation.py): fix linting error * test: update test * test: update tests * fix: fix linting errors * fix: fix test * fix: fix linting error
This commit is contained in:
parent
93532e00db
commit
6ba3c4a4f8
64 changed files with 780 additions and 185 deletions
|
@@ -1,15 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
get_args,
|
||||
)
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args
|
||||
|
||||
import httpx
|
||||
|
||||
|
@@ -35,8 +26,9 @@ hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://hug
|
|||
]
|
||||
|
||||
|
||||
|
||||
def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
|
||||
def get_hf_task_embedding_for_model(
|
||||
model: str, task_type: Optional[str], api_base: str
|
||||
) -> Optional[str]:
|
||||
if task_type is not None:
|
||||
if task_type in get_args(hf_tasks_embeddings):
|
||||
return task_type
|
||||
|
@@ -57,7 +49,9 @@ def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_ba
|
|||
return pipeline_tag
|
||||
|
||||
|
||||
async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
|
||||
async def async_get_hf_task_embedding_for_model(
|
||||
model: str, task_type: Optional[str], api_base: str
|
||||
) -> Optional[str]:
|
||||
if task_type is not None:
|
||||
if task_type in get_args(hf_tasks_embeddings):
|
||||
return task_type
|
||||
|
@@ -116,7 +110,9 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
input: List,
|
||||
optional_params: dict,
|
||||
) -> dict:
|
||||
hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
|
||||
hf_task = await async_get_hf_task_embedding_for_model(
|
||||
model=model, task_type=task_type, api_base=HF_HUB_URL
|
||||
)
|
||||
|
||||
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
|
||||
|
||||
|
@@ -173,7 +169,9 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
task_type = optional_params.pop("input_type", None)
|
||||
|
||||
if call_type == "sync":
|
||||
hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
|
||||
hf_task = get_hf_task_embedding_for_model(
|
||||
model=model, task_type=task_type, api_base=HF_HUB_URL
|
||||
)
|
||||
elif call_type == "async":
|
||||
return self._async_transform_input(
|
||||
model=model, task_type=task_type, embed_url=embed_url, input=input
|
||||
|
@@ -325,6 +323,7 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
input: list,
|
||||
model_response: EmbeddingResponse,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
encoding: Callable,
|
||||
api_key: Optional[str] = None,
|
||||
|
@@ -341,9 +340,12 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
model=model,
|
||||
optional_params=optional_params,
|
||||
messages=[],
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
task_type = optional_params.pop("input_type", None)
|
||||
task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
|
||||
task = get_hf_task_embedding_for_model(
|
||||
model=model, task_type=task_type, api_base=HF_HUB_URL
|
||||
)
|
||||
# print_verbose(f"{model}, {task}")
|
||||
embed_url = ""
|
||||
if "https" in model:
|
||||
|
@@ -355,7 +357,9 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
elif "HUGGINGFACE_API_BASE" in os.environ:
|
||||
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
|
||||
else:
|
||||
embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
|
||||
embed_url = (
|
||||
f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
|
||||
)
|
||||
|
||||
## ROUTING ##
|
||||
if aembedding is True:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue