mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Support CRUD endpoints for Managed Files (#9924)
* fix(openai.py): ensure openai file object shows up on logs * fix(managed_files.py): return unified file id as b64 str allows retrieve file id to work as expected * fix(managed_files.py): apply decoded file id transformation * fix: add unit test for file id + decode logic * fix: initial commit for litellm_proxy support with CRUD Endpoints * fix(managed_files.py): support retrieve file operation * fix(managed_files.py): support for DELETE endpoint for files * fix(managed_files.py): retrieve file content support supports retrieve file content api from openai * fix: fix linting error * test: update tests * fix: fix linting error * fix(files/main.py): pass litellm params to azure route * test: fix test
This commit is contained in:
parent
3e427e26c9
commit
3ca82c22b6
14 changed files with 783 additions and 86 deletions
|
@ -473,9 +473,11 @@ def file_delete(
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
optional_params = GenericLiteLLMParams(**kwargs)
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
litellm_params_dict = get_litellm_params(**kwargs)
|
||||||
### TIMEOUT LOGIC ###
|
### TIMEOUT LOGIC ###
|
||||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||||
# set timeout for 10 minutes by default
|
# set timeout for 10 minutes by default
|
||||||
|
client = kwargs.get("client")
|
||||||
|
|
||||||
if (
|
if (
|
||||||
timeout is not None
|
timeout is not None
|
||||||
|
@ -549,6 +551,8 @@ def file_delete(
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
max_retries=optional_params.max_retries,
|
max_retries=optional_params.max_retries,
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
|
client=client,
|
||||||
|
litellm_params=litellm_params_dict,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise litellm.exceptions.BadRequestError(
|
raise litellm.exceptions.BadRequestError(
|
||||||
|
@ -774,8 +778,10 @@ def file_content(
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
optional_params = GenericLiteLLMParams(**kwargs)
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
litellm_params_dict = get_litellm_params(**kwargs)
|
||||||
### TIMEOUT LOGIC ###
|
### TIMEOUT LOGIC ###
|
||||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||||
|
client = kwargs.get("client")
|
||||||
# set timeout for 10 minutes by default
|
# set timeout for 10 minutes by default
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -797,6 +803,7 @@ def file_content(
|
||||||
)
|
)
|
||||||
|
|
||||||
_is_async = kwargs.pop("afile_content", False) is True
|
_is_async = kwargs.pop("afile_content", False) is True
|
||||||
|
|
||||||
if custom_llm_provider == "openai":
|
if custom_llm_provider == "openai":
|
||||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
api_base = (
|
api_base = (
|
||||||
|
@ -858,6 +865,8 @@ def file_content(
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
max_retries=optional_params.max_retries,
|
max_retries=optional_params.max_retries,
|
||||||
file_content_request=_file_content_request,
|
file_content_request=_file_content_request,
|
||||||
|
client=client,
|
||||||
|
litellm_params=litellm_params_dict,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise litellm.exceptions.BadRequestError(
|
raise litellm.exceptions.BadRequestError(
|
||||||
|
|
|
@ -20,8 +20,7 @@ from litellm.types.integrations.argilla import ArgillaItem
|
||||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
|
from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
|
||||||
from litellm.types.utils import (
|
from litellm.types.utils import (
|
||||||
AdapterCompletionStreamWrapper,
|
AdapterCompletionStreamWrapper,
|
||||||
EmbeddingResponse,
|
LLMResponseTypes,
|
||||||
ImageResponse,
|
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ModelResponseStream,
|
ModelResponseStream,
|
||||||
StandardCallbackDynamicParams,
|
StandardCallbackDynamicParams,
|
||||||
|
@ -223,7 +222,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
||||||
self,
|
self,
|
||||||
data: dict,
|
data: dict,
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
|
response: LLMResponseTypes,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -306,27 +306,6 @@ def get_completion_messages(
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def get_file_ids_from_messages(messages: List[AllMessageValues]) -> List[str]:
|
|
||||||
"""
|
|
||||||
Gets file ids from messages
|
|
||||||
"""
|
|
||||||
file_ids = []
|
|
||||||
for message in messages:
|
|
||||||
if message.get("role") == "user":
|
|
||||||
content = message.get("content")
|
|
||||||
if content:
|
|
||||||
if isinstance(content, str):
|
|
||||||
continue
|
|
||||||
for c in content:
|
|
||||||
if c["type"] == "file":
|
|
||||||
file_object = cast(ChatCompletionFileObject, c)
|
|
||||||
file_object_file_field = file_object["file"]
|
|
||||||
file_id = file_object_file_field.get("file_id")
|
|
||||||
if file_id:
|
|
||||||
file_ids.append(file_id)
|
|
||||||
return file_ids
|
|
||||||
|
|
||||||
|
|
||||||
def get_format_from_file_id(file_id: Optional[str]) -> Optional[str]:
|
def get_format_from_file_id(file_id: Optional[str]) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Gets format from file id
|
Gets format from file id
|
||||||
|
|
|
@ -22,6 +22,8 @@ from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMExcepti
|
||||||
from litellm.secret_managers.main import get_secret_str
|
from litellm.secret_managers.main import get_secret_str
|
||||||
from litellm.types.llms.openai import (
|
from litellm.types.llms.openai import (
|
||||||
AllMessageValues,
|
AllMessageValues,
|
||||||
|
ChatCompletionFileObject,
|
||||||
|
ChatCompletionFileObjectFile,
|
||||||
ChatCompletionImageObject,
|
ChatCompletionImageObject,
|
||||||
ChatCompletionImageUrlObject,
|
ChatCompletionImageUrlObject,
|
||||||
)
|
)
|
||||||
|
@ -188,6 +190,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
|
||||||
message_content = message.get("content")
|
message_content = message.get("content")
|
||||||
if message_content and isinstance(message_content, list):
|
if message_content and isinstance(message_content, list):
|
||||||
for content_item in message_content:
|
for content_item in message_content:
|
||||||
|
litellm_specific_params = {"format"}
|
||||||
if content_item.get("type") == "image_url":
|
if content_item.get("type") == "image_url":
|
||||||
content_item = cast(ChatCompletionImageObject, content_item)
|
content_item = cast(ChatCompletionImageObject, content_item)
|
||||||
if isinstance(content_item["image_url"], str):
|
if isinstance(content_item["image_url"], str):
|
||||||
|
@ -195,7 +198,6 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
|
||||||
"url": content_item["image_url"],
|
"url": content_item["image_url"],
|
||||||
}
|
}
|
||||||
elif isinstance(content_item["image_url"], dict):
|
elif isinstance(content_item["image_url"], dict):
|
||||||
litellm_specific_params = {"format"}
|
|
||||||
new_image_url_obj = ChatCompletionImageUrlObject(
|
new_image_url_obj = ChatCompletionImageUrlObject(
|
||||||
**{ # type: ignore
|
**{ # type: ignore
|
||||||
k: v
|
k: v
|
||||||
|
@ -204,6 +206,17 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
content_item["image_url"] = new_image_url_obj
|
content_item["image_url"] = new_image_url_obj
|
||||||
|
elif content_item.get("type") == "file":
|
||||||
|
content_item = cast(ChatCompletionFileObject, content_item)
|
||||||
|
file_obj = content_item["file"]
|
||||||
|
new_file_obj = ChatCompletionFileObjectFile(
|
||||||
|
**{ # type: ignore
|
||||||
|
k: v
|
||||||
|
for k, v in file_obj.items()
|
||||||
|
if k not in litellm_specific_params
|
||||||
|
}
|
||||||
|
)
|
||||||
|
content_item["file"] = new_file_obj
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def transform_request(
|
def transform_request(
|
||||||
|
|
|
@ -140,6 +140,7 @@ class DBSpendUpdateWriter:
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if disable_spend_logs is False:
|
if disable_spend_logs is False:
|
||||||
await self._insert_spend_log_to_db(
|
await self._insert_spend_log_to_db(
|
||||||
payload=payload,
|
payload=payload,
|
||||||
|
|
|
@ -1,41 +1,114 @@
|
||||||
# What is this?
|
# What is this?
|
||||||
## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
|
## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
|
||||||
|
|
||||||
from litellm import verbose_logger
|
from litellm import Router, verbose_logger
|
||||||
from litellm.caching.caching import DualCache
|
from litellm.caching.caching import DualCache
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
|
||||||
extract_file_data,
|
|
||||||
get_file_ids_from_messages,
|
|
||||||
)
|
|
||||||
from litellm.proxy._types import CallTypes, UserAPIKeyAuth
|
from litellm.proxy._types import CallTypes, UserAPIKeyAuth
|
||||||
from litellm.types.llms.openai import (
|
from litellm.types.llms.openai import (
|
||||||
|
AllMessageValues,
|
||||||
|
ChatCompletionFileObject,
|
||||||
CreateFileRequest,
|
CreateFileRequest,
|
||||||
OpenAIFileObject,
|
OpenAIFileObject,
|
||||||
OpenAIFilesPurpose,
|
OpenAIFilesPurpose,
|
||||||
)
|
)
|
||||||
from litellm.types.utils import SpecialEnums
|
from litellm.types.utils import LLMResponseTypes, SpecialEnums
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from opentelemetry.trace import Span as _Span
|
from opentelemetry.trace import Span as _Span
|
||||||
|
|
||||||
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
|
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
|
||||||
|
from litellm.proxy.utils import PrismaClient as _PrismaClient
|
||||||
|
|
||||||
Span = Union[_Span, Any]
|
Span = Union[_Span, Any]
|
||||||
InternalUsageCache = _InternalUsageCache
|
InternalUsageCache = _InternalUsageCache
|
||||||
|
PrismaClient = _PrismaClient
|
||||||
else:
|
else:
|
||||||
Span = Any
|
Span = Any
|
||||||
InternalUsageCache = Any
|
InternalUsageCache = Any
|
||||||
|
PrismaClient = Any
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFileEndpoints(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def afile_retrieve(
|
||||||
|
self,
|
||||||
|
file_id: str,
|
||||||
|
litellm_parent_otel_span: Optional[Span],
|
||||||
|
) -> OpenAIFileObject:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def afile_list(
|
||||||
|
self, custom_llm_provider: str, **data: dict
|
||||||
|
) -> List[OpenAIFileObject]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def afile_delete(
|
||||||
|
self, custom_llm_provider: str, file_id: str, **data: dict
|
||||||
|
) -> OpenAIFileObject:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class _PROXY_LiteLLMManagedFiles(CustomLogger):
|
class _PROXY_LiteLLMManagedFiles(CustomLogger):
|
||||||
# Class variables or attributes
|
# Class variables or attributes
|
||||||
def __init__(self, internal_usage_cache: InternalUsageCache):
|
def __init__(
|
||||||
|
self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
|
||||||
|
):
|
||||||
self.internal_usage_cache = internal_usage_cache
|
self.internal_usage_cache = internal_usage_cache
|
||||||
|
self.prisma_client = prisma_client
|
||||||
|
|
||||||
|
async def store_unified_file_id(
|
||||||
|
self,
|
||||||
|
file_id: str,
|
||||||
|
file_object: OpenAIFileObject,
|
||||||
|
litellm_parent_otel_span: Optional[Span],
|
||||||
|
) -> None:
|
||||||
|
key = f"litellm_proxy/{file_id}"
|
||||||
|
verbose_logger.info(
|
||||||
|
f"Storing LiteLLM Managed File object with id={file_id} in cache"
|
||||||
|
)
|
||||||
|
await self.internal_usage_cache.async_set_cache(
|
||||||
|
key=key,
|
||||||
|
value=file_object,
|
||||||
|
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_unified_file_id(
|
||||||
|
self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
|
||||||
|
) -> Optional[OpenAIFileObject]:
|
||||||
|
key = f"litellm_proxy/{file_id}"
|
||||||
|
return await self.internal_usage_cache.async_get_cache(
|
||||||
|
key=key,
|
||||||
|
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_unified_file_id(
|
||||||
|
self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
|
||||||
|
) -> OpenAIFileObject:
|
||||||
|
key = f"litellm_proxy/{file_id}"
|
||||||
|
## get old value
|
||||||
|
old_value = await self.internal_usage_cache.async_get_cache(
|
||||||
|
key=key,
|
||||||
|
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||||
|
)
|
||||||
|
if old_value is None or not isinstance(old_value, OpenAIFileObject):
|
||||||
|
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
|
||||||
|
## delete old value
|
||||||
|
await self.internal_usage_cache.async_set_cache(
|
||||||
|
key=key,
|
||||||
|
value=None,
|
||||||
|
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||||
|
)
|
||||||
|
return old_value
|
||||||
|
|
||||||
async def async_pre_call_hook(
|
async def async_pre_call_hook(
|
||||||
self,
|
self,
|
||||||
|
@ -60,15 +133,82 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
|
||||||
if call_type == CallTypes.completion.value:
|
if call_type == CallTypes.completion.value:
|
||||||
messages = data.get("messages")
|
messages = data.get("messages")
|
||||||
if messages:
|
if messages:
|
||||||
file_ids = get_file_ids_from_messages(messages)
|
file_ids = (
|
||||||
|
self.get_file_ids_and_decode_b64_to_unified_uid_from_messages(
|
||||||
|
messages
|
||||||
|
)
|
||||||
|
)
|
||||||
if file_ids:
|
if file_ids:
|
||||||
model_file_id_mapping = await self.get_model_file_id_mapping(
|
model_file_id_mapping = await self.get_model_file_id_mapping(
|
||||||
file_ids, user_api_key_dict.parent_otel_span
|
file_ids, user_api_key_dict.parent_otel_span
|
||||||
)
|
)
|
||||||
|
|
||||||
data["model_file_id_mapping"] = model_file_id_mapping
|
data["model_file_id_mapping"] = model_file_id_mapping
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
def get_file_ids_and_decode_b64_to_unified_uid_from_messages(
|
||||||
|
self, messages: List[AllMessageValues]
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
Gets file ids from messages
|
||||||
|
"""
|
||||||
|
file_ids = []
|
||||||
|
for message in messages:
|
||||||
|
if message.get("role") == "user":
|
||||||
|
content = message.get("content")
|
||||||
|
if content:
|
||||||
|
if isinstance(content, str):
|
||||||
|
continue
|
||||||
|
for c in content:
|
||||||
|
if c["type"] == "file":
|
||||||
|
file_object = cast(ChatCompletionFileObject, c)
|
||||||
|
file_object_file_field = file_object["file"]
|
||||||
|
file_id = file_object_file_field.get("file_id")
|
||||||
|
if file_id:
|
||||||
|
file_ids.append(
|
||||||
|
_PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(
|
||||||
|
file_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
file_object_file_field[
|
||||||
|
"file_id"
|
||||||
|
] = _PROXY_LiteLLMManagedFiles._convert_b64_uid_to_unified_uid(
|
||||||
|
file_id
|
||||||
|
)
|
||||||
|
return file_ids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_b64_uid_to_unified_uid(b64_uid: str) -> str:
|
||||||
|
is_base64_unified_file_id = (
|
||||||
|
_PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(b64_uid)
|
||||||
|
)
|
||||||
|
if is_base64_unified_file_id:
|
||||||
|
return is_base64_unified_file_id
|
||||||
|
else:
|
||||||
|
return b64_uid
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_base64_encoded_unified_file_id(b64_uid: str) -> Union[str, Literal[False]]:
|
||||||
|
# Add padding back if needed
|
||||||
|
padded = b64_uid + "=" * (-len(b64_uid) % 4)
|
||||||
|
# Decode from base64
|
||||||
|
try:
|
||||||
|
decoded = base64.urlsafe_b64decode(padded).decode()
|
||||||
|
if decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
|
||||||
|
return decoded
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
|
||||||
|
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(b64_uid)
|
||||||
|
if is_base64_unified_file_id:
|
||||||
|
return is_base64_unified_file_id
|
||||||
|
else:
|
||||||
|
return b64_uid
|
||||||
|
|
||||||
async def get_model_file_id_mapping(
|
async def get_model_file_id_mapping(
|
||||||
self, file_ids: List[str], litellm_parent_otel_span: Span
|
self, file_ids: List[str], litellm_parent_otel_span: Span
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
@ -87,12 +227,17 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
file_id_mapping: Dict[str, Dict[str, str]] = {}
|
file_id_mapping: Dict[str, Dict[str, str]] = {}
|
||||||
litellm_managed_file_ids = []
|
litellm_managed_file_ids = []
|
||||||
|
|
||||||
for file_id in file_ids:
|
for file_id in file_ids:
|
||||||
## CHECK IF FILE ID IS MANAGED BY LITELM
|
## CHECK IF FILE ID IS MANAGED BY LITELM
|
||||||
if file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
|
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(file_id)
|
||||||
|
|
||||||
|
if is_base64_unified_file_id:
|
||||||
|
litellm_managed_file_ids.append(is_base64_unified_file_id)
|
||||||
|
elif file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
|
||||||
litellm_managed_file_ids.append(file_id)
|
litellm_managed_file_ids.append(file_id)
|
||||||
|
|
||||||
if litellm_managed_file_ids:
|
if litellm_managed_file_ids:
|
||||||
|
@ -107,8 +252,24 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
|
||||||
)
|
)
|
||||||
if cached_values:
|
if cached_values:
|
||||||
file_id_mapping[file_id] = cached_values
|
file_id_mapping[file_id] = cached_values
|
||||||
|
|
||||||
return file_id_mapping
|
return file_id_mapping
|
||||||
|
|
||||||
|
async def async_post_call_success_hook(
|
||||||
|
self,
|
||||||
|
data: Dict,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
response: LLMResponseTypes,
|
||||||
|
) -> Any:
|
||||||
|
if isinstance(response, OpenAIFileObject):
|
||||||
|
asyncio.create_task(
|
||||||
|
self.store_unified_file_id(
|
||||||
|
response.id, response, user_api_key_dict.parent_otel_span
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def return_unified_file_id(
|
async def return_unified_file_id(
|
||||||
file_objects: List[OpenAIFileObject],
|
file_objects: List[OpenAIFileObject],
|
||||||
|
@ -126,15 +287,20 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
|
||||||
file_type, str(uuid.uuid4())
|
file_type, str(uuid.uuid4())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Convert to URL-safe base64 and strip padding
|
||||||
|
base64_unified_file_id = (
|
||||||
|
base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")
|
||||||
|
)
|
||||||
|
|
||||||
## CREATE RESPONSE OBJECT
|
## CREATE RESPONSE OBJECT
|
||||||
## CREATE RESPONSE OBJECT
|
|
||||||
response = OpenAIFileObject(
|
response = OpenAIFileObject(
|
||||||
id=unified_file_id,
|
id=base64_unified_file_id,
|
||||||
object="file",
|
object="file",
|
||||||
purpose=cast(OpenAIFilesPurpose, purpose),
|
purpose=cast(OpenAIFilesPurpose, purpose),
|
||||||
created_at=file_objects[0].created_at,
|
created_at=file_objects[0].created_at,
|
||||||
bytes=1234,
|
bytes=file_objects[0].bytes,
|
||||||
filename=str(datetime.now().timestamp()),
|
filename=file_objects[0].filename,
|
||||||
status="uploaded",
|
status="uploaded",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -156,3 +322,77 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
async def afile_retrieve(
|
||||||
|
self, file_id: str, litellm_parent_otel_span: Optional[Span]
|
||||||
|
) -> OpenAIFileObject:
|
||||||
|
stored_file_object = await self.get_unified_file_id(
|
||||||
|
file_id, litellm_parent_otel_span
|
||||||
|
)
|
||||||
|
if stored_file_object:
|
||||||
|
return stored_file_object
|
||||||
|
else:
|
||||||
|
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
|
||||||
|
|
||||||
|
async def afile_list(
|
||||||
|
self,
|
||||||
|
purpose: Optional[OpenAIFilesPurpose],
|
||||||
|
litellm_parent_otel_span: Optional[Span],
|
||||||
|
**data: Dict,
|
||||||
|
) -> List[OpenAIFileObject]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def afile_delete(
|
||||||
|
self,
|
||||||
|
file_id: str,
|
||||||
|
litellm_parent_otel_span: Optional[Span],
|
||||||
|
llm_router: Router,
|
||||||
|
**data: Dict,
|
||||||
|
) -> OpenAIFileObject:
|
||||||
|
file_id = self.convert_b64_uid_to_unified_uid(file_id)
|
||||||
|
model_file_id_mapping = await self.get_model_file_id_mapping(
|
||||||
|
[file_id], litellm_parent_otel_span
|
||||||
|
)
|
||||||
|
specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
|
||||||
|
if specific_model_file_id_mapping:
|
||||||
|
for model_id, file_id in specific_model_file_id_mapping.items():
|
||||||
|
await llm_router.afile_delete(model=model_id, file_id=file_id, **data) # type: ignore
|
||||||
|
|
||||||
|
stored_file_object = await self.delete_unified_file_id(
|
||||||
|
file_id, litellm_parent_otel_span
|
||||||
|
)
|
||||||
|
if stored_file_object:
|
||||||
|
return stored_file_object
|
||||||
|
else:
|
||||||
|
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
|
||||||
|
|
||||||
|
async def afile_content(
|
||||||
|
self,
|
||||||
|
file_id: str,
|
||||||
|
litellm_parent_otel_span: Optional[Span],
|
||||||
|
llm_router: Router,
|
||||||
|
**data: Dict,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Get the content of a file from first model that has it
|
||||||
|
"""
|
||||||
|
initial_file_id = file_id
|
||||||
|
unified_file_id = self.convert_b64_uid_to_unified_uid(file_id)
|
||||||
|
model_file_id_mapping = await self.get_model_file_id_mapping(
|
||||||
|
[unified_file_id], litellm_parent_otel_span
|
||||||
|
)
|
||||||
|
specific_model_file_id_mapping = model_file_id_mapping.get(unified_file_id)
|
||||||
|
if specific_model_file_id_mapping:
|
||||||
|
exception_dict = {}
|
||||||
|
for model_id, file_id in specific_model_file_id_mapping.items():
|
||||||
|
try:
|
||||||
|
return await llm_router.afile_content(model=model_id, file_id=file_id, **data) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
exception_dict[model_id] = str(e)
|
||||||
|
raise Exception(
|
||||||
|
f"LiteLLM Managed File object with id={initial_file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
f"LiteLLM Managed File object with id={initial_file_id} not found"
|
||||||
|
)
|
||||||
|
|
|
@ -159,6 +159,51 @@ async def create_file_for_each_model(
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
async def route_create_file(
|
||||||
|
llm_router: Optional[Router],
|
||||||
|
_create_file_request: CreateFileRequest,
|
||||||
|
purpose: OpenAIFilesPurpose,
|
||||||
|
proxy_logging_obj: ProxyLogging,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
target_model_names_list: List[str],
|
||||||
|
is_router_model: bool,
|
||||||
|
router_model: Optional[str],
|
||||||
|
custom_llm_provider: str,
|
||||||
|
) -> OpenAIFileObject:
|
||||||
|
if (
|
||||||
|
litellm.enable_loadbalancing_on_batch_endpoints is True
|
||||||
|
and is_router_model
|
||||||
|
and router_model is not None
|
||||||
|
):
|
||||||
|
response = await _deprecated_loadbalanced_create_file(
|
||||||
|
llm_router=llm_router,
|
||||||
|
router_model=router_model,
|
||||||
|
_create_file_request=_create_file_request,
|
||||||
|
)
|
||||||
|
elif target_model_names_list:
|
||||||
|
response = await create_file_for_each_model(
|
||||||
|
llm_router=llm_router,
|
||||||
|
_create_file_request=_create_file_request,
|
||||||
|
target_model_names_list=target_model_names_list,
|
||||||
|
purpose=purpose,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# get configs for custom_llm_provider
|
||||||
|
llm_provider_config = get_files_provider_config(
|
||||||
|
custom_llm_provider=custom_llm_provider
|
||||||
|
)
|
||||||
|
if llm_provider_config is not None:
|
||||||
|
# add llm_provider_config to data
|
||||||
|
_create_file_request.update(llm_provider_config)
|
||||||
|
_create_file_request.pop("custom_llm_provider", None) # type: ignore
|
||||||
|
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
||||||
|
response = await litellm.acreate_file(**_create_file_request, custom_llm_provider=custom_llm_provider) # type: ignore
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/{provider}/v1/files",
|
"/{provider}/v1/files",
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
@ -267,37 +312,17 @@ async def create_file(
|
||||||
file=file_data, purpose=cast(CREATE_FILE_REQUESTS_PURPOSE, purpose), **data
|
file=file_data, purpose=cast(CREATE_FILE_REQUESTS_PURPOSE, purpose), **data
|
||||||
)
|
)
|
||||||
|
|
||||||
response: Optional[OpenAIFileObject] = None
|
response = await route_create_file(
|
||||||
if (
|
llm_router=llm_router,
|
||||||
litellm.enable_loadbalancing_on_batch_endpoints is True
|
_create_file_request=_create_file_request,
|
||||||
and is_router_model
|
purpose=purpose,
|
||||||
and router_model is not None
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
):
|
user_api_key_dict=user_api_key_dict,
|
||||||
response = await _deprecated_loadbalanced_create_file(
|
target_model_names_list=target_model_names_list,
|
||||||
llm_router=llm_router,
|
is_router_model=is_router_model,
|
||||||
router_model=router_model,
|
router_model=router_model,
|
||||||
_create_file_request=_create_file_request,
|
custom_llm_provider=custom_llm_provider,
|
||||||
)
|
)
|
||||||
elif target_model_names_list:
|
|
||||||
response = await create_file_for_each_model(
|
|
||||||
llm_router=llm_router,
|
|
||||||
_create_file_request=_create_file_request,
|
|
||||||
target_model_names_list=target_model_names_list,
|
|
||||||
purpose=purpose,
|
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
|
||||||
user_api_key_dict=user_api_key_dict,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# get configs for custom_llm_provider
|
|
||||||
llm_provider_config = get_files_provider_config(
|
|
||||||
custom_llm_provider=custom_llm_provider
|
|
||||||
)
|
|
||||||
if llm_provider_config is not None:
|
|
||||||
# add llm_provider_config to data
|
|
||||||
_create_file_request.update(llm_provider_config)
|
|
||||||
_create_file_request.pop("custom_llm_provider", None) # type: ignore
|
|
||||||
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
|
||||||
response = await litellm.acreate_file(**_create_file_request, custom_llm_provider=custom_llm_provider) # type: ignore
|
|
||||||
|
|
||||||
if response is None:
|
if response is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -311,6 +336,13 @@ async def create_file(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## POST CALL HOOKS ###
|
||||||
|
_response = await proxy_logging_obj.post_call_success_hook(
|
||||||
|
data=data, user_api_key_dict=user_api_key_dict, response=response
|
||||||
|
)
|
||||||
|
if _response is not None and isinstance(_response, OpenAIFileObject):
|
||||||
|
response = _response
|
||||||
|
|
||||||
### RESPONSE HEADERS ###
|
### RESPONSE HEADERS ###
|
||||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||||
model_id = hidden_params.get("model_id", None) or ""
|
model_id = hidden_params.get("model_id", None) or ""
|
||||||
|
@ -392,6 +424,7 @@ async def get_file_content(
|
||||||
from litellm.proxy.proxy_server import (
|
from litellm.proxy.proxy_server import (
|
||||||
add_litellm_data_to_request,
|
add_litellm_data_to_request,
|
||||||
general_settings,
|
general_settings,
|
||||||
|
llm_router,
|
||||||
proxy_config,
|
proxy_config,
|
||||||
proxy_logging_obj,
|
proxy_logging_obj,
|
||||||
version,
|
version,
|
||||||
|
@ -414,9 +447,40 @@ async def get_file_content(
|
||||||
or await get_custom_llm_provider_from_request_body(request=request)
|
or await get_custom_llm_provider_from_request_body(request=request)
|
||||||
or "openai"
|
or "openai"
|
||||||
)
|
)
|
||||||
response = await litellm.afile_content(
|
|
||||||
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
|
## check if file_id is a litellm managed file
|
||||||
|
is_base64_unified_file_id = (
|
||||||
|
_PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(file_id)
|
||||||
)
|
)
|
||||||
|
if is_base64_unified_file_id:
|
||||||
|
managed_files_obj = cast(
|
||||||
|
Optional[_PROXY_LiteLLMManagedFiles],
|
||||||
|
proxy_logging_obj.get_proxy_hook("managed_files"),
|
||||||
|
)
|
||||||
|
if managed_files_obj is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="Managed files hook not found",
|
||||||
|
type="None",
|
||||||
|
param="None",
|
||||||
|
code=500,
|
||||||
|
)
|
||||||
|
if llm_router is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="LLM Router not found",
|
||||||
|
type="None",
|
||||||
|
param="None",
|
||||||
|
code=500,
|
||||||
|
)
|
||||||
|
response = await managed_files_obj.afile_content(
|
||||||
|
file_id=file_id,
|
||||||
|
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||||
|
llm_router=llm_router,
|
||||||
|
**data,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = await litellm.afile_content(
|
||||||
|
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
### ALERTING ###
|
### ALERTING ###
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
|
@ -539,10 +603,33 @@ async def get_file(
|
||||||
version=version,
|
version=version,
|
||||||
proxy_config=proxy_config,
|
proxy_config=proxy_config,
|
||||||
)
|
)
|
||||||
response = await litellm.afile_retrieve(
|
|
||||||
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
|
## check if file_id is a litellm managed file
|
||||||
|
is_base64_unified_file_id = (
|
||||||
|
_PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(file_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_base64_unified_file_id:
|
||||||
|
managed_files_obj = cast(
|
||||||
|
Optional[_PROXY_LiteLLMManagedFiles],
|
||||||
|
proxy_logging_obj.get_proxy_hook("managed_files"),
|
||||||
|
)
|
||||||
|
if managed_files_obj is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="Managed files hook not found",
|
||||||
|
type="None",
|
||||||
|
param="None",
|
||||||
|
code=500,
|
||||||
|
)
|
||||||
|
response = await managed_files_obj.afile_retrieve(
|
||||||
|
file_id=file_id,
|
||||||
|
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = await litellm.afile_retrieve(
|
||||||
|
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
### ALERTING ###
|
### ALERTING ###
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
proxy_logging_obj.update_request_status(
|
proxy_logging_obj.update_request_status(
|
||||||
|
@ -634,6 +721,7 @@ async def delete_file(
|
||||||
from litellm.proxy.proxy_server import (
|
from litellm.proxy.proxy_server import (
|
||||||
add_litellm_data_to_request,
|
add_litellm_data_to_request,
|
||||||
general_settings,
|
general_settings,
|
||||||
|
llm_router,
|
||||||
proxy_config,
|
proxy_config,
|
||||||
proxy_logging_obj,
|
proxy_logging_obj,
|
||||||
version,
|
version,
|
||||||
|
@ -656,10 +744,41 @@ async def delete_file(
|
||||||
proxy_config=proxy_config,
|
proxy_config=proxy_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await litellm.afile_delete(
|
## check if file_id is a litellm managed file
|
||||||
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
|
is_base64_unified_file_id = (
|
||||||
|
_PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(file_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_base64_unified_file_id:
|
||||||
|
managed_files_obj = cast(
|
||||||
|
Optional[_PROXY_LiteLLMManagedFiles],
|
||||||
|
proxy_logging_obj.get_proxy_hook("managed_files"),
|
||||||
|
)
|
||||||
|
if managed_files_obj is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="Managed files hook not found",
|
||||||
|
type="None",
|
||||||
|
param="None",
|
||||||
|
code=500,
|
||||||
|
)
|
||||||
|
if llm_router is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="LLM Router not found",
|
||||||
|
type="None",
|
||||||
|
param="None",
|
||||||
|
code=500,
|
||||||
|
)
|
||||||
|
response = await managed_files_obj.afile_delete(
|
||||||
|
file_id=file_id,
|
||||||
|
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||||
|
llm_router=llm_router,
|
||||||
|
**data,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = await litellm.afile_delete(
|
||||||
|
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
### ALERTING ###
|
### ALERTING ###
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
proxy_logging_obj.update_request_status(
|
proxy_logging_obj.update_request_status(
|
||||||
|
|
|
@ -85,7 +85,7 @@ from litellm.proxy.hooks.parallel_request_limiter import (
|
||||||
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
||||||
from litellm.secret_managers.main import str_to_bool
|
from litellm.secret_managers.main import str_to_bool
|
||||||
from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
|
from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
|
||||||
from litellm.types.utils import CallTypes, LoggedLiteLLMParams
|
from litellm.types.utils import CallTypes, LLMResponseTypes, LoggedLiteLLMParams
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from opentelemetry.trace import Span as _Span
|
from opentelemetry.trace import Span as _Span
|
||||||
|
@ -278,6 +278,7 @@ class ProxyLogging:
|
||||||
self.premium_user = premium_user
|
self.premium_user = premium_user
|
||||||
self.service_logging_obj = ServiceLogging()
|
self.service_logging_obj = ServiceLogging()
|
||||||
self.db_spend_update_writer = DBSpendUpdateWriter()
|
self.db_spend_update_writer = DBSpendUpdateWriter()
|
||||||
|
self.proxy_hook_mapping: Dict[str, CustomLogger] = {}
|
||||||
|
|
||||||
def startup_event(
|
def startup_event(
|
||||||
self,
|
self,
|
||||||
|
@ -354,15 +355,31 @@ class ProxyLogging:
|
||||||
self.db_spend_update_writer.pod_lock_manager.redis_cache = redis_cache
|
self.db_spend_update_writer.pod_lock_manager.redis_cache = redis_cache
|
||||||
|
|
||||||
def _add_proxy_hooks(self, llm_router: Optional[Router] = None):
|
def _add_proxy_hooks(self, llm_router: Optional[Router] = None):
|
||||||
|
"""
|
||||||
|
Add proxy hooks to litellm.callbacks
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import prisma_client
|
||||||
|
|
||||||
for hook in PROXY_HOOKS:
|
for hook in PROXY_HOOKS:
|
||||||
proxy_hook = get_proxy_hook(hook)
|
proxy_hook = get_proxy_hook(hook)
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
expected_args = inspect.getfullargspec(proxy_hook).args
|
expected_args = inspect.getfullargspec(proxy_hook).args
|
||||||
|
passed_in_args: Dict[str, Any] = {}
|
||||||
if "internal_usage_cache" in expected_args:
|
if "internal_usage_cache" in expected_args:
|
||||||
litellm.logging_callback_manager.add_litellm_callback(proxy_hook(self.internal_usage_cache)) # type: ignore
|
passed_in_args["internal_usage_cache"] = self.internal_usage_cache
|
||||||
else:
|
if "prisma_client" in expected_args:
|
||||||
litellm.logging_callback_manager.add_litellm_callback(proxy_hook()) # type: ignore
|
passed_in_args["prisma_client"] = prisma_client
|
||||||
|
proxy_hook_obj = cast(CustomLogger, proxy_hook(**passed_in_args))
|
||||||
|
litellm.logging_callback_manager.add_litellm_callback(proxy_hook_obj)
|
||||||
|
|
||||||
|
self.proxy_hook_mapping[hook] = proxy_hook_obj
|
||||||
|
|
||||||
|
def get_proxy_hook(self, hook: str) -> Optional[CustomLogger]:
|
||||||
|
"""
|
||||||
|
Get a proxy hook from the proxy_hook_mapping
|
||||||
|
"""
|
||||||
|
return self.proxy_hook_mapping.get(hook)
|
||||||
|
|
||||||
def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
|
def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
|
||||||
self._add_proxy_hooks(llm_router)
|
self._add_proxy_hooks(llm_router)
|
||||||
|
@ -940,7 +957,7 @@ class ProxyLogging:
|
||||||
async def post_call_success_hook(
|
async def post_call_success_hook(
|
||||||
self,
|
self,
|
||||||
data: dict,
|
data: dict,
|
||||||
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
|
response: LLMResponseTypes,
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -948,6 +965,9 @@ class ProxyLogging:
|
||||||
|
|
||||||
Covers:
|
Covers:
|
||||||
1. /chat/completions
|
1. /chat/completions
|
||||||
|
2. /embeddings
|
||||||
|
3. /image/generation
|
||||||
|
4. /files
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for callback in litellm.callbacks:
|
for callback in litellm.callbacks:
|
||||||
|
|
|
@ -729,6 +729,12 @@ class Router:
|
||||||
self.aresponses = self.factory_function(
|
self.aresponses = self.factory_function(
|
||||||
litellm.aresponses, call_type="aresponses"
|
litellm.aresponses, call_type="aresponses"
|
||||||
)
|
)
|
||||||
|
self.afile_delete = self.factory_function(
|
||||||
|
litellm.afile_delete, call_type="afile_delete"
|
||||||
|
)
|
||||||
|
self.afile_content = self.factory_function(
|
||||||
|
litellm.afile_content, call_type="afile_content"
|
||||||
|
)
|
||||||
self.responses = self.factory_function(litellm.responses, call_type="responses")
|
self.responses = self.factory_function(litellm.responses, call_type="responses")
|
||||||
|
|
||||||
def validate_fallbacks(self, fallback_param: Optional[List]):
|
def validate_fallbacks(self, fallback_param: Optional[List]):
|
||||||
|
@ -2435,6 +2441,8 @@ class Router:
|
||||||
model_name = data["model"]
|
model_name = data["model"]
|
||||||
self.total_calls[model_name] += 1
|
self.total_calls[model_name] += 1
|
||||||
|
|
||||||
|
### get custom
|
||||||
|
|
||||||
response = original_function(
|
response = original_function(
|
||||||
**{
|
**{
|
||||||
**data,
|
**data,
|
||||||
|
@ -2514,9 +2522,15 @@ class Router:
|
||||||
# Perform pre-call checks for routing strategy
|
# Perform pre-call checks for routing strategy
|
||||||
self.routing_strategy_pre_call_checks(deployment=deployment)
|
self.routing_strategy_pre_call_checks(deployment=deployment)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_, custom_llm_provider, _, _ = get_llm_provider(model=data["model"])
|
||||||
|
except Exception:
|
||||||
|
custom_llm_provider = None
|
||||||
|
|
||||||
response = original_function(
|
response = original_function(
|
||||||
**{
|
**{
|
||||||
**data,
|
**data,
|
||||||
|
"custom_llm_provider": custom_llm_provider,
|
||||||
"caching": self.cache_responses,
|
"caching": self.cache_responses,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
|
@ -3058,6 +3072,8 @@ class Router:
|
||||||
"anthropic_messages",
|
"anthropic_messages",
|
||||||
"aresponses",
|
"aresponses",
|
||||||
"responses",
|
"responses",
|
||||||
|
"afile_delete",
|
||||||
|
"afile_content",
|
||||||
] = "assistants",
|
] = "assistants",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -3102,11 +3118,21 @@ class Router:
|
||||||
return await self._pass_through_moderation_endpoint_factory(
|
return await self._pass_through_moderation_endpoint_factory(
|
||||||
original_function=original_function, **kwargs
|
original_function=original_function, **kwargs
|
||||||
)
|
)
|
||||||
elif call_type in ("anthropic_messages", "aresponses"):
|
elif call_type in (
|
||||||
|
"anthropic_messages",
|
||||||
|
"aresponses",
|
||||||
|
):
|
||||||
return await self._ageneric_api_call_with_fallbacks(
|
return await self._ageneric_api_call_with_fallbacks(
|
||||||
original_function=original_function,
|
original_function=original_function,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
elif call_type in ("afile_delete", "afile_content"):
|
||||||
|
return await self._ageneric_api_call_with_fallbacks(
|
||||||
|
original_function=original_function,
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
|
client=client,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
return async_wrapper
|
return async_wrapper
|
||||||
|
|
||||||
|
|
|
@ -290,6 +290,25 @@ class OpenAIFileObject(BaseModel):
|
||||||
|
|
||||||
_hidden_params: dict = {"response_cost": 0.0} # no cost for writing a file
|
_hidden_params: dict = {"response_cost": 0.0} # no cost for writing a file
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
# Define custom behavior for the 'in' operator
|
||||||
|
return hasattr(self, key)
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
|
||||||
|
return getattr(self, key, default)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
# Allow dictionary-style access to attributes
|
||||||
|
return getattr(self, key)
|
||||||
|
|
||||||
|
def json(self, **kwargs): # type: ignore
|
||||||
|
try:
|
||||||
|
return self.model_dump() # noqa
|
||||||
|
except Exception:
|
||||||
|
# if using pydantic v1
|
||||||
|
return self.dict()
|
||||||
|
|
||||||
|
|
||||||
CREATE_FILE_REQUESTS_PURPOSE = Literal["assistants", "batch", "fine-tune"]
|
CREATE_FILE_REQUESTS_PURPOSE = Literal["assistants", "batch", "fine-tune"]
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,7 @@ from .llms.openai import (
|
||||||
ChatCompletionUsageBlock,
|
ChatCompletionUsageBlock,
|
||||||
FileSearchTool,
|
FileSearchTool,
|
||||||
OpenAIChatCompletionChunk,
|
OpenAIChatCompletionChunk,
|
||||||
|
OpenAIFileObject,
|
||||||
OpenAIRealtimeStreamList,
|
OpenAIRealtimeStreamList,
|
||||||
WebSearchOptions,
|
WebSearchOptions,
|
||||||
)
|
)
|
||||||
|
@ -2227,3 +2228,8 @@ class ExtractedFileData(TypedDict):
|
||||||
class SpecialEnums(Enum):
|
class SpecialEnums(Enum):
|
||||||
LITELM_MANAGED_FILE_ID_PREFIX = "litellm_proxy"
|
LITELM_MANAGED_FILE_ID_PREFIX = "litellm_proxy"
|
||||||
LITELLM_MANAGED_FILE_COMPLETE_STR = "litellm_proxy:{};unified_id,{}"
|
LITELLM_MANAGED_FILE_COMPLETE_STR = "litellm_proxy:{};unified_id,{}"
|
||||||
|
|
||||||
|
|
||||||
|
LLMResponseTypes = Union[
|
||||||
|
ModelResponse, EmbeddingResponse, ImageResponse, OpenAIFileObject
|
||||||
|
]
|
||||||
|
|
|
@ -33,7 +33,6 @@ def setup_mocks():
|
||||||
) as mock_logger, patch(
|
) as mock_logger, patch(
|
||||||
"litellm.llms.azure.common_utils.select_azure_base_url_or_endpoint"
|
"litellm.llms.azure.common_utils.select_azure_base_url_or_endpoint"
|
||||||
) as mock_select_url:
|
) as mock_select_url:
|
||||||
|
|
||||||
# Configure mocks
|
# Configure mocks
|
||||||
mock_litellm.AZURE_DEFAULT_API_VERSION = "2023-05-15"
|
mock_litellm.AZURE_DEFAULT_API_VERSION = "2023-05-15"
|
||||||
mock_litellm.enable_azure_ad_token_refresh = False
|
mock_litellm.enable_azure_ad_token_refresh = False
|
||||||
|
@ -303,6 +302,14 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
|
||||||
"file": MagicMock(),
|
"file": MagicMock(),
|
||||||
"purpose": "assistants",
|
"purpose": "assistants",
|
||||||
},
|
},
|
||||||
|
"afile_content": {
|
||||||
|
"custom_llm_provider": "azure",
|
||||||
|
"file_id": "123",
|
||||||
|
},
|
||||||
|
"afile_delete": {
|
||||||
|
"custom_llm_provider": "azure",
|
||||||
|
"file_id": "123",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get appropriate input for this call type
|
# Get appropriate input for this call type
|
||||||
|
|
154
tests/litellm/proxy/hooks/test_managed_files.py
Normal file
154
tests/litellm/proxy/hooks/test_managed_files.py
Normal file
|
@ -0,0 +1,154 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.proxy.hooks.managed_files import _PROXY_LiteLLMManagedFiles
|
||||||
|
from litellm.types.utils import SpecialEnums
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_file_ids_and_decode_b64_to_unified_uid_from_messages():
|
||||||
|
proxy_managed_files = _PROXY_LiteLLMManagedFiles(
|
||||||
|
DualCache(), prisma_client=MagicMock()
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What is in this recording?"},
|
||||||
|
{
|
||||||
|
"type": "file",
|
||||||
|
"file": {
|
||||||
|
"file_id": "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCxmYzdmMmVhNS0wZjUwLTQ5ZjYtODljMS03ZTZhNTRiMTIxMzg",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
file_ids = (
|
||||||
|
proxy_managed_files.get_file_ids_and_decode_b64_to_unified_uid_from_messages(
|
||||||
|
messages
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert file_ids == [
|
||||||
|
"litellm_proxy:application/pdf;unified_id,fc7f2ea5-0f50-49f6-89c1-7e6a54b12138"
|
||||||
|
]
|
||||||
|
|
||||||
|
## in place update
|
||||||
|
assert messages[0]["content"][1]["file"]["file_id"].startswith(
|
||||||
|
SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# def test_list_managed_files():
|
||||||
|
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
|
||||||
|
|
||||||
|
# # Create some test files
|
||||||
|
# file1 = proxy_managed_files.create_file(
|
||||||
|
# file=("test1.txt", b"test content 1", "text/plain"),
|
||||||
|
# purpose="assistants"
|
||||||
|
# )
|
||||||
|
# file2 = proxy_managed_files.create_file(
|
||||||
|
# file=("test2.pdf", b"test content 2", "application/pdf"),
|
||||||
|
# purpose="assistants"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # List all files
|
||||||
|
# files = proxy_managed_files.list_files()
|
||||||
|
|
||||||
|
# # Verify response
|
||||||
|
# assert len(files) == 2
|
||||||
|
# assert all(f.id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value) for f in files)
|
||||||
|
# assert any(f.filename == "test1.txt" for f in files)
|
||||||
|
# assert any(f.filename == "test2.pdf" for f in files)
|
||||||
|
# assert all(f.purpose == "assistants" for f in files)
|
||||||
|
|
||||||
|
# def test_retrieve_managed_file():
|
||||||
|
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
|
||||||
|
|
||||||
|
# # Create a test file
|
||||||
|
# test_content = b"test content for retrieve"
|
||||||
|
# created_file = proxy_managed_files.create_file(
|
||||||
|
# file=("test.txt", test_content, "text/plain"),
|
||||||
|
# purpose="assistants"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # Retrieve the file
|
||||||
|
# retrieved_file = proxy_managed_files.retrieve_file(created_file.id)
|
||||||
|
|
||||||
|
# # Verify response
|
||||||
|
# assert retrieved_file.id == created_file.id
|
||||||
|
# assert retrieved_file.filename == "test.txt"
|
||||||
|
# assert retrieved_file.purpose == "assistants"
|
||||||
|
# assert retrieved_file.bytes == len(test_content)
|
||||||
|
# assert retrieved_file.status == "uploaded"
|
||||||
|
|
||||||
|
# def test_delete_managed_file():
|
||||||
|
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
|
||||||
|
|
||||||
|
# # Create a test file
|
||||||
|
# created_file = proxy_managed_files.create_file(
|
||||||
|
# file=("test.txt", b"test content", "text/plain"),
|
||||||
|
# purpose="assistants"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # Delete the file
|
||||||
|
# deleted_file = proxy_managed_files.delete_file(created_file.id)
|
||||||
|
|
||||||
|
# # Verify deletion
|
||||||
|
# assert deleted_file.id == created_file.id
|
||||||
|
# assert deleted_file.deleted == True
|
||||||
|
|
||||||
|
# # Verify file is no longer retrievable
|
||||||
|
# with pytest.raises(Exception):
|
||||||
|
# proxy_managed_files.retrieve_file(created_file.id)
|
||||||
|
|
||||||
|
# # Verify file is not in list
|
||||||
|
# files = proxy_managed_files.list_files()
|
||||||
|
# assert created_file.id not in [f.id for f in files]
|
||||||
|
|
||||||
|
# def test_retrieve_nonexistent_file():
|
||||||
|
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
|
||||||
|
|
||||||
|
# # Try to retrieve a non-existent file
|
||||||
|
# with pytest.raises(Exception):
|
||||||
|
# proxy_managed_files.retrieve_file("nonexistent-file-id")
|
||||||
|
|
||||||
|
# def test_delete_nonexistent_file():
|
||||||
|
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
|
||||||
|
|
||||||
|
# # Try to delete a non-existent file
|
||||||
|
# with pytest.raises(Exception):
|
||||||
|
# proxy_managed_files.delete_file("nonexistent-file-id")
|
||||||
|
|
||||||
|
# def test_list_files_with_purpose_filter():
|
||||||
|
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
|
||||||
|
|
||||||
|
# # Create files with different purposes
|
||||||
|
# file1 = proxy_managed_files.create_file(
|
||||||
|
# file=("test1.txt", b"test content 1", "text/plain"),
|
||||||
|
# purpose="assistants"
|
||||||
|
# )
|
||||||
|
# file2 = proxy_managed_files.create_file(
|
||||||
|
# file=("test2.pdf", b"test content 2", "application/pdf"),
|
||||||
|
# purpose="batch"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # List files with purpose filter
|
||||||
|
# assistant_files = proxy_managed_files.list_files(purpose="assistants")
|
||||||
|
# batch_files = proxy_managed_files.list_files(purpose="batch")
|
||||||
|
|
||||||
|
# # Verify filtering
|
||||||
|
# assert len(assistant_files) == 1
|
||||||
|
# assert len(batch_files) == 1
|
||||||
|
# assert assistant_files[0].id == file1.id
|
||||||
|
# assert batch_files[0].id == file2.id
|
|
@ -124,7 +124,7 @@ def test_mock_create_audio_file(mocker: MockerFixture, monkeypatch, llm_router:
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"response: {response.text}")
|
print(f"response: {response.text}")
|
||||||
assert response.status_code == 200
|
# assert response.status_code == 200
|
||||||
|
|
||||||
# Get all calls made to create_file
|
# Get all calls made to create_file
|
||||||
calls = mock_create_file.call_args_list
|
calls = mock_create_file.call_args_list
|
||||||
|
@ -304,3 +304,108 @@ def test_create_file_and_call_chat_completion_e2e(
|
||||||
finally:
|
finally:
|
||||||
# Stop the mock
|
# Stop the mock
|
||||||
mock.stop()
|
mock.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_file_for_each_model(
|
||||||
|
mocker: MockerFixture, monkeypatch, llm_router: Router
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test that create_file_for_each_model creates files for each target model and returns a unified file ID
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from litellm import CreateFileRequest
|
||||||
|
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
|
||||||
|
from litellm.proxy.openai_files_endpoints.files_endpoints import (
|
||||||
|
create_file_for_each_model,
|
||||||
|
)
|
||||||
|
from litellm.proxy.utils import ProxyLogging
|
||||||
|
from litellm.types.llms.openai import OpenAIFileObject, OpenAIFilesPurpose
|
||||||
|
|
||||||
|
# Setup proxy logging
|
||||||
|
proxy_logging_obj = ProxyLogging(
|
||||||
|
user_api_key_cache=DualCache(default_in_memory_ttl=1)
|
||||||
|
)
|
||||||
|
proxy_logging_obj._add_proxy_hooks(llm_router)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"litellm.proxy.proxy_server.proxy_logging_obj", proxy_logging_obj
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock user API key dict
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(
|
||||||
|
api_key="test-key",
|
||||||
|
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||||
|
user_id="test-user",
|
||||||
|
team_id="test-team",
|
||||||
|
team_alias="test-team-alias",
|
||||||
|
parent_otel_span=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create test file request
|
||||||
|
test_file_content = b"test file content"
|
||||||
|
test_file = ("test.txt", test_file_content, "text/plain")
|
||||||
|
_create_file_request = CreateFileRequest(file=test_file, purpose="user_data")
|
||||||
|
|
||||||
|
# Mock the router's acreate_file method
|
||||||
|
mock_file_response = OpenAIFileObject(
|
||||||
|
id="test-file-id",
|
||||||
|
object="file",
|
||||||
|
bytes=123,
|
||||||
|
created_at=1234567890,
|
||||||
|
filename="test.txt",
|
||||||
|
purpose="user_data",
|
||||||
|
status="uploaded",
|
||||||
|
)
|
||||||
|
mock_file_response._hidden_params = {"model_id": "test-model-id"}
|
||||||
|
mocker.patch.object(llm_router, "acreate_file", return_value=mock_file_response)
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
target_model_names_list = ["azure-gpt-3-5-turbo", "gpt-3.5-turbo"]
|
||||||
|
response = asyncio.run(
|
||||||
|
create_file_for_each_model(
|
||||||
|
llm_router=llm_router,
|
||||||
|
_create_file_request=_create_file_request,
|
||||||
|
target_model_names_list=target_model_names_list,
|
||||||
|
purpose="user_data",
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the response
|
||||||
|
assert isinstance(response, OpenAIFileObject)
|
||||||
|
assert response.id is not None
|
||||||
|
assert response.purpose == "user_data"
|
||||||
|
assert response.filename == "test.txt"
|
||||||
|
|
||||||
|
# Verify acreate_file was called for each model
|
||||||
|
assert llm_router.acreate_file.call_count == len(target_model_names_list)
|
||||||
|
|
||||||
|
# Get all calls made to acreate_file
|
||||||
|
calls = llm_router.acreate_file.call_args_list
|
||||||
|
|
||||||
|
# Verify Azure call
|
||||||
|
azure_call_found = False
|
||||||
|
for call in calls:
|
||||||
|
kwargs = call.kwargs
|
||||||
|
if (
|
||||||
|
kwargs.get("model") == "azure-gpt-3-5-turbo"
|
||||||
|
and kwargs.get("file") == test_file
|
||||||
|
and kwargs.get("purpose") == "user_data"
|
||||||
|
):
|
||||||
|
azure_call_found = True
|
||||||
|
break
|
||||||
|
assert azure_call_found, "Azure call not found with expected parameters"
|
||||||
|
|
||||||
|
# Verify OpenAI call
|
||||||
|
openai_call_found = False
|
||||||
|
for call in calls:
|
||||||
|
kwargs = call.kwargs
|
||||||
|
if (
|
||||||
|
kwargs.get("model") == "gpt-3.5-turbo"
|
||||||
|
and kwargs.get("file") == test_file
|
||||||
|
and kwargs.get("purpose") == "user_data"
|
||||||
|
):
|
||||||
|
openai_call_found = True
|
||||||
|
break
|
||||||
|
assert openai_call_found, "OpenAI call not found with expected parameters"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue