Mirror of https://github.com/BerriAI/litellm.git
fix: initial commit for litellm_proxy support with CRUD Endpoints
parent 6b04b48b17
commit 7fff83e441
6 changed files with 197 additions and 9 deletions
@@ -20,8 +20,7 @@ from litellm.types.integrations.argilla import ArgillaItem
 from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
 from litellm.types.utils import (
     AdapterCompletionStreamWrapper,
-    EmbeddingResponse,
-    ImageResponse,
+    LLMResponseTypes,
     ModelResponse,
     ModelResponseStream,
     StandardCallbackDynamicParams,
@@ -223,7 +222,7 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
         self,
         data: dict,
         user_api_key_dict: UserAPIKeyAuth,
-        response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
+        response: LLMResponseTypes,
     ) -> Any:
         pass
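
For context, a callback that relies on the widened hook signature might look like the sketch below. The hook name async_post_call_success_hook and the import paths are assumptions based on LiteLLM's custom-callback interface; the example is illustrative and not part of this commit.

# Sketch: a custom callback using the new LLMResponseTypes alias (assumed hook name).
from typing import Any

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.utils import LLMResponseTypes


class FileAwareLogger(CustomLogger):
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: LLMResponseTypes,
    ) -> Any:
        # The response may now also be an OpenAIFileObject, not just a
        # chat/embedding/image response.
        print(f"proxy returned a {type(response).__name__}")
        return response
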
@@ -1,10 +1,12 @@
 # What is this?
 ## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
 
+import asyncio
 import base64
 import uuid
+from abc import ABC, abstractmethod
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
 
 from litellm import verbose_logger
 from litellm.caching.caching import DualCache
@@ -18,24 +20,58 @@ from litellm.types.llms.openai import (
     OpenAIFileObject,
     OpenAIFilesPurpose,
 )
-from litellm.types.utils import SpecialEnums
+from litellm.types.utils import LLMResponseTypes, SpecialEnums
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
 
     from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
+    from litellm.proxy.utils import PrismaClient as _PrismaClient
 
     Span = Union[_Span, Any]
     InternalUsageCache = _InternalUsageCache
+    PrismaClient = _PrismaClient
 else:
     Span = Any
     InternalUsageCache = Any
+    PrismaClient = Any
 
 
-class _PROXY_LiteLLMManagedFiles(CustomLogger):
+class BaseFileEndpoints(ABC):
+    @abstractmethod
+    async def afile_retrieve(
+        self, custom_llm_provider: str, file_id: str, **data: dict
+    ) -> OpenAIFileObject:
+        pass
+
+    @abstractmethod
+    async def afile_list(
+        self, custom_llm_provider: str, **data: dict
+    ) -> List[OpenAIFileObject]:
+        pass
+
+    @abstractmethod
+    async def afile_delete(
+        self, custom_llm_provider: str, file_id: str, **data: dict
+    ) -> OpenAIFileObject:
+        pass
+
+
+class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
     # Class variables or attributes
-    def __init__(self, internal_usage_cache: InternalUsageCache):
+    def __init__(
+        self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
+    ):
         self.internal_usage_cache = internal_usage_cache
+        self.prisma_client = prisma_client
+
+    async def store_unified_file_id(
+        self, file_id: str, file_object: OpenAIFileObject
+    ) -> None:
+        pass
+
+    async def get_unified_file_id(self, file_id: str) -> Optional[OpenAIFileObject]:
+        return None
 
     async def async_pre_call_hook(
         self,
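
As a reading aid, a minimal concrete implementation of the new BaseFileEndpoints abstract class could look like the sketch below. The in-memory dictionary and the module path litellm.proxy.hooks.managed_files are assumptions for illustration only.

# Sketch: an in-memory implementation of the assumed BaseFileEndpoints contract.
from typing import Dict, List

from litellm.proxy.hooks.managed_files import BaseFileEndpoints
from litellm.types.llms.openai import OpenAIFileObject


class InMemoryFileEndpoints(BaseFileEndpoints):
    def __init__(self) -> None:
        self._files: Dict[str, OpenAIFileObject] = {}  # hypothetical backing store

    async def afile_retrieve(
        self, custom_llm_provider: str, file_id: str, **data: dict
    ) -> OpenAIFileObject:
        return self._files[file_id]

    async def afile_list(
        self, custom_llm_provider: str, **data: dict
    ) -> List[OpenAIFileObject]:
        return list(self._files.values())

    async def afile_delete(
        self, custom_llm_provider: str, file_id: str, **data: dict
    ) -> OpenAIFileObject:
        return self._files.pop(file_id)
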
@@ -176,6 +212,17 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
 
         return file_id_mapping
 
+    async def post_call_success_hook(
+        self,
+        data: Dict,
+        response: LLMResponseTypes,
+        user_api_key_dict: UserAPIKeyAuth,
+    ) -> Any:
+        if isinstance(response, OpenAIFileObject):
+            asyncio.create_task(self.store_unified_file_id(response.id, response))
+
+        return None
+
     @staticmethod
     async def return_unified_file_id(
         file_objects: List[OpenAIFileObject],
@@ -228,3 +275,22 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger):
         )
 
         return response
+
+    async def afile_retrieve(
+        self, custom_llm_provider: str, file_id: str, **data: Dict
+    ) -> OpenAIFileObject:
+        stored_file_object = await self.get_unified_file_id(file_id)
+        if stored_file_object:
+            return stored_file_object
+        else:
+            raise Exception(f"File object with id={file_id} not found")
+
+    async def afile_list(
+        self, custom_llm_provider: str, **data: Dict
+    ) -> List[OpenAIFileObject]:
+        return []
+
+    async def afile_delete(
+        self, custom_llm_provider: str, file_id: str, **data: Dict
+    ) -> OpenAIFileObject:
+        raise NotImplementedError("afile_delete not implemented")
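
Roughly, the new CRUD surface on the managed-files hook can be exercised as in the sketch below. Constructing the hook with MagicMock placeholders, and the module path, are assumptions; in the proxy the real internal usage cache and Prisma client are injected.

# Sketch: exercising afile_list / afile_retrieve with placeholder dependencies.
import asyncio
from unittest.mock import MagicMock

from litellm.proxy.hooks.managed_files import _PROXY_LiteLLMManagedFiles


async def main() -> None:
    managed_files = _PROXY_LiteLLMManagedFiles(
        internal_usage_cache=MagicMock(),  # placeholder for InternalUsageCache
        prisma_client=MagicMock(),  # placeholder for PrismaClient
    )

    print(await managed_files.afile_list(custom_llm_provider="openai"))  # []

    try:
        await managed_files.afile_retrieve(
            custom_llm_provider="openai", file_id="unknown-id"
        )
    except Exception as e:
        print(e)  # File object with id=unknown-id not found


asyncio.run(main())
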
@@ -311,6 +311,13 @@ async def create_file(
             )
         )
 
+        ## POST CALL HOOKS ###
+        _response = await proxy_logging_obj.post_call_success_hook(
+            data=data, user_api_key_dict=user_api_key_dict, response=response
+        )
+        if _response is not None and isinstance(_response, OpenAIFileObject):
+            response = _response
+
         ### RESPONSE HEADERS ###
         hidden_params = getattr(response, "_hidden_params", {}) or {}
         model_id = hidden_params.get("model_id", None) or ""
@@ -539,6 +546,8 @@ async def get_file(
             version=version,
             proxy_config=proxy_config,
         )
+
+        ## check if file_id is a litellm managed file
         response = await litellm.afile_retrieve(
             custom_llm_provider=custom_llm_provider, file_id=file_id, **data  # type: ignore
         )
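
For orientation, the /files routes touched here are typically exercised through the OpenAI SDK pointed at a running proxy; the base URL, API key, and file name below are placeholders, not values from this commit.

# Sketch: client-side calls against a LiteLLM proxy (placeholder URL/key/file).
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

created = client.files.create(file=open("mydata.jsonl", "rb"), purpose="fine-tune")
retrieved = client.files.retrieve(created.id)
print(retrieved.id, retrieved.purpose)
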
@@ -85,7 +85,7 @@ from litellm.proxy.hooks.parallel_request_limiter import (
 from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
 from litellm.secret_managers.main import str_to_bool
 from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
-from litellm.types.utils import CallTypes, LoggedLiteLLMParams
+from litellm.types.utils import CallTypes, LLMResponseTypes, LoggedLiteLLMParams
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -940,7 +940,7 @@ class ProxyLogging:
     async def post_call_success_hook(
         self,
         data: dict,
-        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
+        response: LLMResponseTypes,
         user_api_key_dict: UserAPIKeyAuth,
     ):
         """
@@ -948,6 +948,9 @@ class ProxyLogging:
 
         Covers:
         1. /chat/completions
+        2. /embeddings
+        3. /image/generation
+        4. /files
         """
 
         for callback in litellm.callbacks:
@@ -34,6 +34,7 @@ from .llms.openai import (
     ChatCompletionUsageBlock,
     FileSearchTool,
     OpenAIChatCompletionChunk,
+    OpenAIFileObject,
     OpenAIRealtimeStreamList,
     WebSearchOptions,
 )
@@ -2226,3 +2227,8 @@ class ExtractedFileData(TypedDict):
 class SpecialEnums(Enum):
     LITELM_MANAGED_FILE_ID_PREFIX = "litellm_proxy"
     LITELLM_MANAGED_FILE_COMPLETE_STR = "litellm_proxy:{};unified_id,{}"
+
+
+LLMResponseTypes = Union[
+    ModelResponse, EmbeddingResponse, ImageResponse, OpenAIFileObject
+]
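
A small sketch of how downstream code can narrow the new LLMResponseTypes union (purely illustrative; the helper function is hypothetical):

# Sketch: isinstance narrowing over the LLMResponseTypes union.
from litellm.types.llms.openai import OpenAIFileObject
from litellm.types.utils import LLMResponseTypes, ModelResponse


def describe(response: LLMResponseTypes) -> str:
    if isinstance(response, OpenAIFileObject):
        return f"file {response.id}"
    if isinstance(response, ModelResponse):
        return f"chat completion {response.id}"
    return type(response).__name__
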
@@ -43,3 +43,108 @@ def test_get_file_ids_and_decode_b64_to_unified_uid_from_messages():
     assert messages[0]["content"][1]["file"]["file_id"].startswith(
         SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value
     )
+
+
+# def test_list_managed_files():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Create some test files
+#     file1 = proxy_managed_files.create_file(
+#         file=("test1.txt", b"test content 1", "text/plain"),
+#         purpose="assistants"
+#     )
+#     file2 = proxy_managed_files.create_file(
+#         file=("test2.pdf", b"test content 2", "application/pdf"),
+#         purpose="assistants"
+#     )
+
+#     # List all files
+#     files = proxy_managed_files.list_files()
+
+#     # Verify response
+#     assert len(files) == 2
+#     assert all(f.id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value) for f in files)
+#     assert any(f.filename == "test1.txt" for f in files)
+#     assert any(f.filename == "test2.pdf" for f in files)
+#     assert all(f.purpose == "assistants" for f in files)
+
+# def test_retrieve_managed_file():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Create a test file
+#     test_content = b"test content for retrieve"
+#     created_file = proxy_managed_files.create_file(
+#         file=("test.txt", test_content, "text/plain"),
+#         purpose="assistants"
+#     )
+
+#     # Retrieve the file
+#     retrieved_file = proxy_managed_files.retrieve_file(created_file.id)
+
+#     # Verify response
+#     assert retrieved_file.id == created_file.id
+#     assert retrieved_file.filename == "test.txt"
+#     assert retrieved_file.purpose == "assistants"
+#     assert retrieved_file.bytes == len(test_content)
+#     assert retrieved_file.status == "uploaded"
+
+# def test_delete_managed_file():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Create a test file
+#     created_file = proxy_managed_files.create_file(
+#         file=("test.txt", b"test content", "text/plain"),
+#         purpose="assistants"
+#     )
+
+#     # Delete the file
+#     deleted_file = proxy_managed_files.delete_file(created_file.id)
+
+#     # Verify deletion
+#     assert deleted_file.id == created_file.id
+#     assert deleted_file.deleted == True
+
+#     # Verify file is no longer retrievable
+#     with pytest.raises(Exception):
+#         proxy_managed_files.retrieve_file(created_file.id)
+
+#     # Verify file is not in list
+#     files = proxy_managed_files.list_files()
+#     assert created_file.id not in [f.id for f in files]
+
+# def test_retrieve_nonexistent_file():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Try to retrieve a non-existent file
+#     with pytest.raises(Exception):
+#         proxy_managed_files.retrieve_file("nonexistent-file-id")
+
+# def test_delete_nonexistent_file():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Try to delete a non-existent file
+#     with pytest.raises(Exception):
+#         proxy_managed_files.delete_file("nonexistent-file-id")
+
+# def test_list_files_with_purpose_filter():
+#     proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
+
+#     # Create files with different purposes
+#     file1 = proxy_managed_files.create_file(
+#         file=("test1.txt", b"test content 1", "text/plain"),
+#         purpose="assistants"
+#     )
+#     file2 = proxy_managed_files.create_file(
+#         file=("test2.pdf", b"test content 2", "application/pdf"),
+#         purpose="batch"
+#     )
+
+#     # List files with purpose filter
+#     assistant_files = proxy_managed_files.list_files(purpose="assistants")
+#     batch_files = proxy_managed_files.list_files(purpose="batch")
+
+#     # Verify filtering
+#     assert len(assistant_files) == 1
+#     assert len(batch_files) == 1
+#     assert assistant_files[0].id == file1.id
+#     assert batch_files[0].id == file2.id
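
If these commented-out tests are enabled later, an async variant against the current constructor might look like the sketch below; pytest-asyncio and the MagicMock placeholders are assumptions, and this test is not part of the commit.

# Sketch: hypothetical async test for the new CRUD methods (not in this commit).
from unittest.mock import MagicMock

import pytest

from litellm.proxy.hooks.managed_files import _PROXY_LiteLLMManagedFiles


@pytest.mark.asyncio
async def test_afile_retrieve_nonexistent_file_raises():
    proxy_managed_files = _PROXY_LiteLLMManagedFiles(
        internal_usage_cache=MagicMock(), prisma_client=MagicMock()
    )
    with pytest.raises(Exception):
        await proxy_managed_files.afile_retrieve(
            custom_llm_provider="openai", file_id="nonexistent-file-id"
        )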