LiteLLM Minor Fixes & Improvements (04/02/2025) (#9725)

* Add date picker to usage tab + Add reasoning_content token tracking across all providers on streaming (#9722)

* feat(new_usage.tsx): add date picker for new usage tab

allows users to look back at their usage data

* feat(anthropic/chat/transformation.py): report reasoning tokens in completion token details

allows tracking how many reasoning tokens are actually being used
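
A minimal sketch of what this enables on the non-streaming path (the model name and thinking budget below are illustrative, not part of this commit):

```python
import litellm

# Illustrative call to an Anthropic model with extended thinking enabled.
response = litellm.completion(
    model="anthropic/claude-3-7-sonnet-20250219",
    messages=[{"role": "user", "content": "What is 17 * 24?"}],
    thinking={"type": "enabled", "budget_tokens": 1024},
)

# Reasoning token usage is now reported in completion_tokens_details.
details = response.usage.completion_tokens_details
if details is not None:
    print("reasoning tokens:", details.reasoning_tokens)
```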

* feat(streaming_chunk_builder.py): return reasoning_tokens in anthropic/openai streaming response

allows tracking reasoning_token usage across providers
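
Rough streaming counterpart, rebuilding the final response from the collected chunks with `litellm.stream_chunk_builder` (same illustrative model as above):

```python
import litellm

messages = [{"role": "user", "content": "What is 17 * 24?"}]
chunks = []
for chunk in litellm.completion(
    model="anthropic/claude-3-7-sonnet-20250219",
    messages=messages,
    thinking={"type": "enabled", "budget_tokens": 1024},
    stream=True,
):
    chunks.append(chunk)

# reasoning_content is joined across chunks and counted into
# usage.completion_tokens_details.reasoning_tokens.
rebuilt = litellm.stream_chunk_builder(chunks, messages=messages)
print(rebuilt.choices[0].message.reasoning_content)
print(rebuilt.usage.completion_tokens_details.reasoning_tokens)
```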

* Fix update team metadata + fix bulk adding models on UI (#9721)

* fix(handle_add_model_submit.tsx): fix bulk adding models

* fix(team_info.tsx): fix team metadata update

Fixes https://github.com/BerriAI/litellm/issues/9689

* (v0) Unified file id - allow calling multiple providers with same file id (#9718)

* feat(files_endpoints.py): initial commit adding 'target_model_names' support

allows developers to specify all the models they want to call with the file

* feat(files_endpoints.py): return unified files endpoint
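
A hedged sketch of the new proxy flow (base URL, API key, and model names are placeholders; `target_model_names` is the new comma-separated form field):

```python
import requests

# Upload the file once; the proxy creates it on every deployment listed in
# target_model_names and returns a single unified "litellm_proxy/..." file id.
resp = requests.post(
    "http://localhost:4000/v1/files",
    headers={"Authorization": "Bearer sk-1234"},
    files={"file": ("recording.wav", open("recording.wav", "rb"), "audio/wav")},
    data={
        "purpose": "user_data",
        "target_model_names": "gemini-2.0-flash, gpt-4o-mini-openai",
    },
)
unified_file_id = resp.json()["id"]  # e.g. "litellm_proxy/<uuid>"
```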

* test(test_files_endpoints.py): add validation test for invalid purpose values

* feat: more updates

* feat: initial working commit of unified file id translation

* fix: additional fixes

* fix(router.py): remove model replace logic in jsonl on acreate_file

enables file upload to work for chat completion requests as well
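
Continuing the sketch above, the unified file id can then be referenced in a chat completion request; the proxy swaps it for the provider-specific file id of whichever deployment is picked (payload mirrors the e2e test added in this PR):

```python
import requests

unified_file_id = "litellm_proxy/..."  # returned by the upload sketch above

completion = requests.post(
    "http://localhost:4000/v1/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gemini-2.0-flash",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is in this recording?"},
                    # The managed-files hook swaps this unified id for the
                    # provider-specific file id of the chosen deployment.
                    {"type": "file", "file": {"file_id": unified_file_id}},
                ],
            }
        ],
    },
)
print(completion.json()["choices"][0]["message"]["content"])
```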

* fix(files_endpoints.py): remove whitespace around model name

* fix(azure/handler.py): return acreate_file with correct response type

* fix: fix linting errors

* test: fix mock test to run on github actions

* fix: fix ruff errors

* fix: fix file too large error

* fix(utils.py): remove redundant var

* test: modify test to work on github actions

* test: update tests

* test: more debug logs to understand ci/cd issue

* test: fix test for respx

* test: skip mock respx test

fails on ci/cd - not clear why

* fix: fix ruff check

* fix: fix test

* fix(model_connection_test.tsx): fix linting error

* test: update unit tests
Krish Dholakia authored on 2025-04-03 11:48:52 -07:00, committed by GitHub
parent ad57b7b331
commit 0ce878e804
27 changed files with 889 additions and 96 deletions

View file

@ -63,16 +63,17 @@ async def acreate_file(
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
kwargs["acreate_file"] = True kwargs["acreate_file"] = True
# Use a partial function to pass your keyword arguments call_args = {
func = partial( "file": file,
create_file, "purpose": purpose,
file, "custom_llm_provider": custom_llm_provider,
purpose, "extra_headers": extra_headers,
custom_llm_provider, "extra_body": extra_body,
extra_headers,
extra_body,
**kwargs, **kwargs,
) }
# Use a partial function to pass your keyword arguments
func = partial(create_file, **call_args)
# Add the context to the function # Add the context to the function
ctx = contextvars.copy_context() ctx = contextvars.copy_context()
@ -92,7 +93,7 @@ async def acreate_file(
def create_file( def create_file(
file: FileTypes, file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune"], purpose: Literal["assistants", "batch", "fine-tune"],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", custom_llm_provider: Optional[Literal["openai", "azure", "vertex_ai"]] = None,
extra_headers: Optional[Dict[str, str]] = None, extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None,
**kwargs, **kwargs,
@ -101,6 +102,8 @@ def create_file(
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API. Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
Specify either provider_list or custom_llm_provider.
""" """
try: try:
_is_async = kwargs.pop("acreate_file", False) is True _is_async = kwargs.pop("acreate_file", False) is True
@ -120,7 +123,7 @@ def create_file(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout

View file

@ -457,8 +457,12 @@ class Logging(LiteLLMLoggingBaseClass):
non_default_params: dict, non_default_params: dict,
prompt_id: str, prompt_id: str,
prompt_variables: Optional[dict], prompt_variables: Optional[dict],
prompt_management_logger: Optional[CustomLogger] = None,
) -> Tuple[str, List[AllMessageValues], dict]: ) -> Tuple[str, List[AllMessageValues], dict]:
custom_logger = self.get_custom_logger_for_prompt_management(model) custom_logger = (
prompt_management_logger
or self.get_custom_logger_for_prompt_management(model)
)
if custom_logger: if custom_logger:
( (
model, model,

View file

@ -7,6 +7,7 @@ from typing import Dict, List, Literal, Optional, Union, cast
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
AllMessageValues, AllMessageValues,
ChatCompletionAssistantMessage, ChatCompletionAssistantMessage,
ChatCompletionFileObject,
ChatCompletionUserMessage, ChatCompletionUserMessage,
) )
from litellm.types.utils import Choices, ModelResponse, StreamingChoices from litellm.types.utils import Choices, ModelResponse, StreamingChoices
@ -292,3 +293,58 @@ def get_completion_messages(
messages, assistant_continue_message, ensure_alternating_roles messages, assistant_continue_message, ensure_alternating_roles
) )
return messages return messages
def get_file_ids_from_messages(messages: List[AllMessageValues]) -> List[str]:
"""
Gets file ids from messages
"""
file_ids = []
for message in messages:
if message.get("role") == "user":
content = message.get("content")
if content:
if isinstance(content, str):
continue
for c in content:
if c["type"] == "file":
file_object = cast(ChatCompletionFileObject, c)
file_object_file_field = file_object["file"]
file_id = file_object_file_field.get("file_id")
if file_id:
file_ids.append(file_id)
return file_ids
def update_messages_with_model_file_ids(
messages: List[AllMessageValues],
model_id: str,
model_file_id_mapping: Dict[str, Dict[str, str]],
) -> List[AllMessageValues]:
"""
Updates messages with model file ids.
model_file_id_mapping: Dict[str, Dict[str, str]] = {
"litellm_proxy/file_id": {
"model_id": "provider_file_id"
}
}
"""
for message in messages:
if message.get("role") == "user":
content = message.get("content")
if content:
if isinstance(content, str):
continue
for c in content:
if c["type"] == "file":
file_object = cast(ChatCompletionFileObject, c)
file_object_file_field = file_object["file"]
file_id = file_object_file_field.get("file_id")
if file_id:
provider_file_id = (
model_file_id_mapping.get(file_id, {}).get(model_id)
or file_id
)
file_object_file_field["file_id"] = provider_file_id
return messages
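
Illustrative use of the two helpers added above (the file ids, model id, and mapping below are made up):

```python
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    get_file_ids_from_messages,
    update_messages_with_model_file_ids,
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this recording?"},
            {"type": "file", "file": {"file_id": "litellm_proxy/abc"}},
        ],
    }
]

print(get_file_ids_from_messages(messages))  # ['litellm_proxy/abc']

# Swap the unified id for the provider file id stored for this deployment.
updated = update_messages_with_model_file_ids(
    messages=messages,
    model_id="gemini-2.0-flash-id",
    model_file_id_mapping={
        "litellm_proxy/abc": {"gemini-2.0-flash-id": "files/provider-123"}
    },
)
print(updated[0]["content"][1]["file"]["file_id"])  # files/provider-123
```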

View file

@ -1,6 +1,6 @@
import base64 import base64
import time import time
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union, cast
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
ChatCompletionAssistantContentValue, ChatCompletionAssistantContentValue,
@ -9,7 +9,9 @@ from litellm.types.llms.openai import (
from litellm.types.utils import ( from litellm.types.utils import (
ChatCompletionAudioResponse, ChatCompletionAudioResponse,
ChatCompletionMessageToolCall, ChatCompletionMessageToolCall,
Choices,
CompletionTokensDetails, CompletionTokensDetails,
CompletionTokensDetailsWrapper,
Function, Function,
FunctionCall, FunctionCall,
ModelResponse, ModelResponse,
@ -203,14 +205,14 @@ class ChunkProcessor:
) )
def get_combined_content( def get_combined_content(
self, chunks: List[Dict[str, Any]] self, chunks: List[Dict[str, Any]], delta_key: str = "content"
) -> ChatCompletionAssistantContentValue: ) -> ChatCompletionAssistantContentValue:
content_list: List[str] = [] content_list: List[str] = []
for chunk in chunks: for chunk in chunks:
choices = chunk["choices"] choices = chunk["choices"]
for choice in choices: for choice in choices:
delta = choice.get("delta", {}) delta = choice.get("delta", {})
content = delta.get("content", "") content = delta.get(delta_key, "")
if content is None: if content is None:
continue # openai v1.0.0 sets content = None for chunks continue # openai v1.0.0 sets content = None for chunks
content_list.append(content) content_list.append(content)
@ -221,6 +223,11 @@ class ChunkProcessor:
# Update the "content" field within the response dictionary # Update the "content" field within the response dictionary
return combined_content return combined_content
def get_combined_reasoning_content(
self, chunks: List[Dict[str, Any]]
) -> ChatCompletionAssistantContentValue:
return self.get_combined_content(chunks, delta_key="reasoning_content")
def get_combined_audio_content( def get_combined_audio_content(
self, chunks: List[Dict[str, Any]] self, chunks: List[Dict[str, Any]]
) -> ChatCompletionAudioResponse: ) -> ChatCompletionAudioResponse:
@ -296,12 +303,27 @@ class ChunkProcessor:
"prompt_tokens_details": prompt_tokens_details, "prompt_tokens_details": prompt_tokens_details,
} }
def count_reasoning_tokens(self, response: ModelResponse) -> int:
reasoning_tokens = 0
for choice in response.choices:
if (
hasattr(cast(Choices, choice).message, "reasoning_content")
and cast(Choices, choice).message.reasoning_content is not None
):
reasoning_tokens += token_counter(
text=cast(Choices, choice).message.reasoning_content,
count_response_tokens=True,
)
return reasoning_tokens
def calculate_usage( def calculate_usage(
self, self,
chunks: List[Union[Dict[str, Any], ModelResponse]], chunks: List[Union[Dict[str, Any], ModelResponse]],
model: str, model: str,
completion_output: str, completion_output: str,
messages: Optional[List] = None, messages: Optional[List] = None,
reasoning_tokens: Optional[int] = None,
) -> Usage: ) -> Usage:
""" """
Calculate usage for the given chunks. Calculate usage for the given chunks.
@ -382,6 +404,19 @@ class ChunkProcessor:
) # for anthropic ) # for anthropic
if completion_tokens_details is not None: if completion_tokens_details is not None:
returned_usage.completion_tokens_details = completion_tokens_details returned_usage.completion_tokens_details = completion_tokens_details
if reasoning_tokens is not None:
if returned_usage.completion_tokens_details is None:
returned_usage.completion_tokens_details = (
CompletionTokensDetailsWrapper(reasoning_tokens=reasoning_tokens)
)
elif (
returned_usage.completion_tokens_details is not None
and returned_usage.completion_tokens_details.reasoning_tokens is None
):
returned_usage.completion_tokens_details.reasoning_tokens = (
reasoning_tokens
)
if prompt_tokens_details is not None: if prompt_tokens_details is not None:
returned_usage.prompt_tokens_details = prompt_tokens_details returned_usage.prompt_tokens_details = prompt_tokens_details

View file

@ -21,7 +21,6 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client, get_async_httpx_client,
) )
from litellm.types.llms.anthropic import ( from litellm.types.llms.anthropic import (
AnthropicChatCompletionUsageBlock,
ContentBlockDelta, ContentBlockDelta,
ContentBlockStart, ContentBlockStart,
ContentBlockStop, ContentBlockStop,
@ -32,13 +31,13 @@ from litellm.types.llms.anthropic import (
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
ChatCompletionThinkingBlock, ChatCompletionThinkingBlock,
ChatCompletionToolCallChunk, ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
) )
from litellm.types.utils import ( from litellm.types.utils import (
Delta, Delta,
GenericStreamingChunk, GenericStreamingChunk,
ModelResponseStream, ModelResponseStream,
StreamingChoices, StreamingChoices,
Usage,
) )
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
@ -487,10 +486,8 @@ class ModelResponseIterator:
return True return True
return False return False
def _handle_usage( def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
self, anthropic_usage_chunk: Union[dict, UsageDelta] usage_block = Usage(
) -> AnthropicChatCompletionUsageBlock:
usage_block = AnthropicChatCompletionUsageBlock(
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0), prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0), completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
total_tokens=anthropic_usage_chunk.get("input_tokens", 0) total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
@ -581,7 +578,7 @@ class ModelResponseIterator:
text = "" text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None tool_use: Optional[ChatCompletionToolCallChunk] = None
finish_reason = "" finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None usage: Optional[Usage] = None
provider_specific_fields: Dict[str, Any] = {} provider_specific_fields: Dict[str, Any] = {}
reasoning_content: Optional[str] = None reasoning_content: Optional[str] = None
thinking_blocks: Optional[List[ChatCompletionThinkingBlock]] = None thinking_blocks: Optional[List[ChatCompletionThinkingBlock]] = None

View file

@ -33,9 +33,16 @@ from litellm.types.llms.openai import (
ChatCompletionToolCallFunctionChunk, ChatCompletionToolCallFunctionChunk,
ChatCompletionToolParam, ChatCompletionToolParam,
) )
from litellm.types.utils import CompletionTokensDetailsWrapper
from litellm.types.utils import Message as LitellmMessage from litellm.types.utils import Message as LitellmMessage
from litellm.types.utils import PromptTokensDetailsWrapper from litellm.types.utils import PromptTokensDetailsWrapper
from litellm.utils import ModelResponse, Usage, add_dummy_tool, has_tool_call_blocks from litellm.utils import (
ModelResponse,
Usage,
add_dummy_tool,
has_tool_call_blocks,
token_counter,
)
from ..common_utils import AnthropicError, process_anthropic_headers from ..common_utils import AnthropicError, process_anthropic_headers
@ -772,6 +779,15 @@ class AnthropicConfig(BaseConfig):
prompt_tokens_details = PromptTokensDetailsWrapper( prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cache_read_input_tokens cached_tokens=cache_read_input_tokens
) )
completion_token_details = (
CompletionTokensDetailsWrapper(
reasoning_tokens=token_counter(
text=reasoning_content, count_response_tokens=True
)
)
if reasoning_content
else None
)
total_tokens = prompt_tokens + completion_tokens total_tokens = prompt_tokens + completion_tokens
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
@ -780,6 +796,7 @@ class AnthropicConfig(BaseConfig):
prompt_tokens_details=prompt_tokens_details, prompt_tokens_details=prompt_tokens_details,
cache_creation_input_tokens=cache_creation_input_tokens, cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens, cache_read_input_tokens=cache_read_input_tokens,
completion_tokens_details=completion_token_details,
) )
setattr(model_response, "usage", usage) # type: ignore setattr(model_response, "usage", usage) # type: ignore

View file

@ -28,11 +28,11 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
self, self,
create_file_data: CreateFileRequest, create_file_data: CreateFileRequest,
openai_client: AsyncAzureOpenAI, openai_client: AsyncAzureOpenAI,
) -> FileObject: ) -> OpenAIFileObject:
verbose_logger.debug("create_file_data=%s", create_file_data) verbose_logger.debug("create_file_data=%s", create_file_data)
response = await openai_client.files.create(**create_file_data) response = await openai_client.files.create(**create_file_data)
verbose_logger.debug("create_file_response=%s", response) verbose_logger.debug("create_file_response=%s", response)
return response return OpenAIFileObject(**response.model_dump())
def create_file( def create_file(
self, self,
@ -66,7 +66,7 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
raise ValueError( raise ValueError(
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client." "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
) )
return self.acreate_file( # type: ignore return self.acreate_file(
create_file_data=create_file_data, openai_client=openai_client create_file_data=create_file_data, openai_client=openai_client
) )
response = cast(AzureOpenAI, openai_client).files.create(**create_file_data) response = cast(AzureOpenAI, openai_client).files.create(**create_file_data)

View file

@ -110,7 +110,10 @@ from .litellm_core_utils.fallback_utils import (
async_completion_with_fallbacks, async_completion_with_fallbacks,
completion_with_fallbacks, completion_with_fallbacks,
) )
from .litellm_core_utils.prompt_templates.common_utils import get_completion_messages from .litellm_core_utils.prompt_templates.common_utils import (
get_completion_messages,
update_messages_with_model_file_ids,
)
from .litellm_core_utils.prompt_templates.factory import ( from .litellm_core_utils.prompt_templates.factory import (
custom_prompt, custom_prompt,
function_call_prompt, function_call_prompt,
@ -953,7 +956,6 @@ def completion( # type: ignore # noqa: PLR0915
non_default_params = get_non_default_completion_params(kwargs=kwargs) non_default_params = get_non_default_completion_params(kwargs=kwargs)
litellm_params = {} # used to prevent unbound var errors litellm_params = {} # used to prevent unbound var errors
## PROMPT MANAGEMENT HOOKS ## ## PROMPT MANAGEMENT HOOKS ##
if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None: if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None:
( (
model, model,
@ -1068,6 +1070,15 @@ def completion( # type: ignore # noqa: PLR0915
if eos_token: if eos_token:
custom_prompt_dict[model]["eos_token"] = eos_token custom_prompt_dict[model]["eos_token"] = eos_token
if kwargs.get("model_file_id_mapping"):
messages = update_messages_with_model_file_ids(
messages=messages,
model_id=kwargs.get("model_info", {}).get("id", None),
model_file_id_mapping=cast(
Dict[str, Dict[str, str]], kwargs.get("model_file_id_mapping")
),
)
provider_config: Optional[BaseConfig] = None provider_config: Optional[BaseConfig] = None
if custom_llm_provider is not None and custom_llm_provider in [ if custom_llm_provider is not None and custom_llm_provider in [
provider.value for provider in LlmProviders provider.value for provider in LlmProviders
@ -5799,6 +5810,19 @@ def stream_chunk_builder( # noqa: PLR0915
"content" "content"
] = processor.get_combined_content(content_chunks) ] = processor.get_combined_content(content_chunks)
reasoning_chunks = [
chunk
for chunk in chunks
if len(chunk["choices"]) > 0
and "reasoning_content" in chunk["choices"][0]["delta"]
and chunk["choices"][0]["delta"]["reasoning_content"] is not None
]
if len(reasoning_chunks) > 0:
response["choices"][0]["message"][
"reasoning_content"
] = processor.get_combined_reasoning_content(reasoning_chunks)
audio_chunks = [ audio_chunks = [
chunk chunk
for chunk in chunks for chunk in chunks
@ -5813,11 +5837,14 @@ def stream_chunk_builder( # noqa: PLR0915
completion_output = get_content_from_model_response(response) completion_output = get_content_from_model_response(response)
reasoning_tokens = processor.count_reasoning_tokens(response)
usage = processor.calculate_usage( usage = processor.calculate_usage(
chunks=chunks, chunks=chunks,
model=model, model=model,
completion_output=completion_output, completion_output=completion_output,
messages=messages, messages=messages,
reasoning_tokens=reasoning_tokens,
) )
setattr(response, "usage", usage) setattr(response, "usage", usage)

View file

@@ -1,18 +1,17 @@
 model_list:
-  - model_name: "gpt-4o"
+  - model_name: "gpt-4o-azure"
     litellm_params:
-      model: azure/chatgpt-v-2
+      model: azure/gpt-4o
       api_key: os.environ/AZURE_API_KEY
-      api_base: http://0.0.0.0:8090
-      rpm: 3
+      api_base: os.environ/AZURE_API_BASE
   - model_name: "gpt-4o-mini-openai"
     litellm_params:
       model: gpt-4o-mini
       api_key: os.environ/OPENAI_API_KEY
   - model_name: "openai/*"
     litellm_params:
       model: openai/*
       api_key: os.environ/OPENAI_API_KEY
   - model_name: "bedrock-nova"
     litellm_params:
       model: us.amazon.nova-pro-v1:0

View file

@ -2688,6 +2688,10 @@ class PrismaCompatibleUpdateDBModel(TypedDict, total=False):
updated_by: str updated_by: str
class SpecialEnums(enum.Enum):
LITELM_MANAGED_FILE_ID_PREFIX = "litellm_proxy/"
class SpecialManagementEndpointEnums(enum.Enum): class SpecialManagementEndpointEnums(enum.Enum):
DEFAULT_ORGANIZATION = "default_organization" DEFAULT_ORGANIZATION = "default_organization"

View file

@ -1 +1,36 @@
from typing import Literal, Union
from . import *
from .cache_control_check import _PROXY_CacheControlCheck
from .managed_files import _PROXY_LiteLLMManagedFiles
from .max_budget_limiter import _PROXY_MaxBudgetLimiter
from .parallel_request_limiter import _PROXY_MaxParallelRequestsHandler
# List of all available hooks that can be enabled
PROXY_HOOKS = {
"max_budget_limiter": _PROXY_MaxBudgetLimiter,
"managed_files": _PROXY_LiteLLMManagedFiles,
"parallel_request_limiter": _PROXY_MaxParallelRequestsHandler,
"cache_control_check": _PROXY_CacheControlCheck,
}
def get_proxy_hook(
hook_name: Union[
Literal[
"max_budget_limiter",
"managed_files",
"parallel_request_limiter",
"cache_control_check",
],
str,
]
):
"""
Factory method to get a proxy hook instance by name
"""
if hook_name not in PROXY_HOOKS:
raise ValueError(
f"Unknown hook: {hook_name}. Available hooks: {list(PROXY_HOOKS.keys())}"
)
return PROXY_HOOKS[hook_name]
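
Small usage sketch for the new hook registry (the managed_files hook is instantiated with the proxy's internal usage cache, as the ProxyLogging change in this commit does):

```python
from litellm.proxy.hooks import PROXY_HOOKS, get_proxy_hook

print(list(PROXY_HOOKS))
# ['max_budget_limiter', 'managed_files', 'parallel_request_limiter', 'cache_control_check']

hook_cls = get_proxy_hook("managed_files")  # -> _PROXY_LiteLLMManagedFiles class

try:
    get_proxy_hook("not_a_hook")
except ValueError as e:
    print(e)  # Unknown hook: not_a_hook. Available hooks: [...]
```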

View file

@ -0,0 +1,145 @@
# What is this?
## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union, cast
from litellm import verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.common_utils import (
get_file_ids_from_messages,
)
from litellm.proxy._types import CallTypes, SpecialEnums, UserAPIKeyAuth
from litellm.types.llms.openai import OpenAIFileObject, OpenAIFilesPurpose
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
Span = Union[_Span, Any]
InternalUsageCache = _InternalUsageCache
else:
Span = Any
InternalUsageCache = Any
class _PROXY_LiteLLMManagedFiles(CustomLogger):
# Class variables or attributes
def __init__(self, internal_usage_cache: InternalUsageCache):
self.internal_usage_cache = internal_usage_cache
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: Dict,
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
"pass_through_endpoint",
"rerank",
],
) -> Union[Exception, str, Dict, None]:
"""
- Detect litellm_proxy/ file_id
- add dictionary of mappings of litellm_proxy/ file_id -> provider_file_id => {litellm_proxy/file_id: {"model_id": id, "file_id": provider_file_id}}
"""
if call_type == CallTypes.completion.value:
messages = data.get("messages")
if messages:
file_ids = get_file_ids_from_messages(messages)
if file_ids:
model_file_id_mapping = await self.get_model_file_id_mapping(
file_ids, user_api_key_dict.parent_otel_span
)
data["model_file_id_mapping"] = model_file_id_mapping
return data
async def get_model_file_id_mapping(
self, file_ids: List[str], litellm_parent_otel_span: Span
) -> dict:
"""
Get model-specific file IDs for a list of proxy file IDs.
Returns a dictionary mapping litellm_proxy/ file_id -> model_id -> model_file_id
1. Get all the litellm_proxy/ file_ids from the messages
2. For each file_id, search for cache keys matching the pattern file_id:*
3. Return a dictionary of mappings of litellm_proxy/ file_id -> model_id -> model_file_id
Example:
{
"litellm_proxy/file_id": {
"model_id": "model_file_id"
}
}
"""
file_id_mapping: Dict[str, Dict[str, str]] = {}
litellm_managed_file_ids = []
for file_id in file_ids:
## CHECK IF FILE ID IS MANAGED BY LITELM
if file_id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
litellm_managed_file_ids.append(file_id)
if litellm_managed_file_ids:
# Get all cache keys matching the pattern file_id:*
for file_id in litellm_managed_file_ids:
# Search for any cache key starting with this file_id
cached_values = cast(
Dict[str, str],
await self.internal_usage_cache.async_get_cache(
key=file_id, litellm_parent_otel_span=litellm_parent_otel_span
),
)
if cached_values:
file_id_mapping[file_id] = cached_values
return file_id_mapping
@staticmethod
async def return_unified_file_id(
file_objects: List[OpenAIFileObject],
purpose: OpenAIFilesPurpose,
internal_usage_cache: InternalUsageCache,
litellm_parent_otel_span: Span,
) -> OpenAIFileObject:
unified_file_id = SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value + str(
uuid.uuid4()
)
## CREATE RESPONSE OBJECT
response = OpenAIFileObject(
id=unified_file_id,
object="file",
purpose=cast(OpenAIFilesPurpose, purpose),
created_at=file_objects[0].created_at,
bytes=1234,
filename=str(datetime.now().timestamp()),
status="uploaded",
)
## STORE RESPONSE IN DB + CACHE
stored_values: Dict[str, str] = {}
for file_object in file_objects:
model_id = file_object._hidden_params.get("model_id")
if model_id is None:
verbose_logger.warning(
f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None"
)
continue
file_id = file_object.id
stored_values[model_id] = file_id
await internal_usage_cache.async_set_cache(
key=unified_file_id,
value=stored_values,
litellm_parent_otel_span=litellm_parent_otel_span,
)
return response

View file

@ -7,7 +7,7 @@
import asyncio import asyncio
import traceback import traceback
from typing import Optional from typing import Optional, cast, get_args
import httpx import httpx
from fastapi import ( from fastapi import (
@ -31,7 +31,10 @@ from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessin
from litellm.proxy.common_utils.openai_endpoint_utils import ( from litellm.proxy.common_utils.openai_endpoint_utils import (
get_custom_llm_provider_from_request_body, get_custom_llm_provider_from_request_body,
) )
from litellm.proxy.hooks.managed_files import _PROXY_LiteLLMManagedFiles
from litellm.proxy.utils import ProxyLogging
from litellm.router import Router from litellm.router import Router
from litellm.types.llms.openai import OpenAIFileObject, OpenAIFilesPurpose
router = APIRouter() router = APIRouter()
@ -104,6 +107,53 @@ def is_known_model(model: Optional[str], llm_router: Optional[Router]) -> bool:
return is_in_list return is_in_list
async def _deprecated_loadbalanced_create_file(
llm_router: Optional[Router],
router_model: str,
_create_file_request: CreateFileRequest,
) -> OpenAIFileObject:
if llm_router is None:
raise HTTPException(
status_code=500,
detail={
"error": "LLM Router not initialized. Ensure models added to proxy."
},
)
response = await llm_router.acreate_file(model=router_model, **_create_file_request)
return response
async def create_file_for_each_model(
llm_router: Optional[Router],
_create_file_request: CreateFileRequest,
target_model_names_list: List[str],
purpose: OpenAIFilesPurpose,
proxy_logging_obj: ProxyLogging,
user_api_key_dict: UserAPIKeyAuth,
) -> OpenAIFileObject:
if llm_router is None:
raise HTTPException(
status_code=500,
detail={
"error": "LLM Router not initialized. Ensure models added to proxy."
},
)
responses = []
for model in target_model_names_list:
individual_response = await llm_router.acreate_file(
model=model, **_create_file_request
)
responses.append(individual_response)
response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id(
file_objects=responses,
purpose=purpose,
internal_usage_cache=proxy_logging_obj.internal_usage_cache,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
return response
@router.post( @router.post(
"/{provider}/v1/files", "/{provider}/v1/files",
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
@ -123,6 +173,7 @@ async def create_file(
request: Request, request: Request,
fastapi_response: Response, fastapi_response: Response,
purpose: str = Form(...), purpose: str = Form(...),
target_model_names: str = Form(default=""),
provider: Optional[str] = None, provider: Optional[str] = None,
custom_llm_provider: str = Form(default="openai"), custom_llm_provider: str = Form(default="openai"),
file: UploadFile = File(...), file: UploadFile = File(...),
@ -162,8 +213,25 @@ async def create_file(
or await get_custom_llm_provider_from_request_body(request=request) or await get_custom_llm_provider_from_request_body(request=request)
or "openai" or "openai"
) )
target_model_names_list = (
target_model_names.split(",") if target_model_names else []
)
target_model_names_list = [model.strip() for model in target_model_names_list]
# Prepare the data for forwarding # Prepare the data for forwarding
# Replace with:
valid_purposes = get_args(OpenAIFilesPurpose)
if purpose not in valid_purposes:
raise HTTPException(
status_code=400,
detail={
"error": f"Invalid purpose: {purpose}. Must be one of: {valid_purposes}",
},
)
# Cast purpose to OpenAIFilesPurpose type
purpose = cast(OpenAIFilesPurpose, purpose)
data = {"purpose": purpose} data = {"purpose": purpose}
# Include original request and headers in the data # Include original request and headers in the data
@ -192,21 +260,25 @@ async def create_file(
_create_file_request = CreateFileRequest(file=file_data, **data) _create_file_request = CreateFileRequest(file=file_data, **data)
response: Optional[OpenAIFileObject] = None
if ( if (
litellm.enable_loadbalancing_on_batch_endpoints is True litellm.enable_loadbalancing_on_batch_endpoints is True
and is_router_model and is_router_model
and router_model is not None and router_model is not None
): ):
if llm_router is None: response = await _deprecated_loadbalanced_create_file(
raise HTTPException( llm_router=llm_router,
status_code=500, router_model=router_model,
detail={ _create_file_request=_create_file_request,
"error": "LLM Router not initialized. Ensure models added to proxy." )
}, elif target_model_names_list:
) response = await create_file_for_each_model(
llm_router=llm_router,
response = await llm_router.acreate_file( _create_file_request=_create_file_request,
model=router_model, **_create_file_request target_model_names_list=target_model_names_list,
purpose=purpose,
proxy_logging_obj=proxy_logging_obj,
user_api_key_dict=user_api_key_dict,
) )
else: else:
# get configs for custom_llm_provider # get configs for custom_llm_provider
@ -220,6 +292,11 @@ async def create_file(
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
response = await litellm.acreate_file(**_create_file_request, custom_llm_provider=custom_llm_provider) # type: ignore response = await litellm.acreate_file(**_create_file_request, custom_llm_provider=custom_llm_provider) # type: ignore
if response is None:
raise HTTPException(
status_code=500,
detail={"error": "Failed to create file. Please try again."},
)
### ALERTING ### ### ALERTING ###
asyncio.create_task( asyncio.create_task(
proxy_logging_obj.update_request_status( proxy_logging_obj.update_request_status(
@ -248,12 +325,11 @@ async def create_file(
await proxy_logging_obj.post_call_failure_hook( await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
) )
verbose_proxy_logger.error( verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.create_file(): Exception occured - {}".format( "litellm.proxy.proxy_server.create_file(): Exception occured - {}".format(
str(e) str(e)
) )
) )
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
raise ProxyException( raise ProxyException(
message=getattr(e, "message", str(e.detail)), message=getattr(e, "message", str(e.detail)),

View file

@ -76,6 +76,7 @@ from litellm.proxy.db.create_views import (
from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
from litellm.proxy.db.log_db_metrics import log_db_metrics from litellm.proxy.db.log_db_metrics import log_db_metrics
from litellm.proxy.db.prisma_client import PrismaWrapper from litellm.proxy.db.prisma_client import PrismaWrapper
from litellm.proxy.hooks import PROXY_HOOKS, get_proxy_hook
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.parallel_request_limiter import ( from litellm.proxy.hooks.parallel_request_limiter import (
@ -352,10 +353,19 @@ class ProxyLogging:
self.db_spend_update_writer.redis_update_buffer.redis_cache = redis_cache self.db_spend_update_writer.redis_update_buffer.redis_cache = redis_cache
self.db_spend_update_writer.pod_lock_manager.redis_cache = redis_cache self.db_spend_update_writer.pod_lock_manager.redis_cache = redis_cache
def _add_proxy_hooks(self, llm_router: Optional[Router] = None):
for hook in PROXY_HOOKS:
proxy_hook = get_proxy_hook(hook)
import inspect
expected_args = inspect.getfullargspec(proxy_hook).args
if "internal_usage_cache" in expected_args:
litellm.logging_callback_manager.add_litellm_callback(proxy_hook(self.internal_usage_cache)) # type: ignore
else:
litellm.logging_callback_manager.add_litellm_callback(proxy_hook()) # type: ignore
def _init_litellm_callbacks(self, llm_router: Optional[Router] = None): def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
litellm.logging_callback_manager.add_litellm_callback(self.max_parallel_request_limiter) # type: ignore self._add_proxy_hooks(llm_router)
litellm.logging_callback_manager.add_litellm_callback(self.max_budget_limiter) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(self.cache_control_check) # type: ignore
litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj) # type: ignore litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj) # type: ignore
for callback in litellm.callbacks: for callback in litellm.callbacks:
if isinstance(callback, str): if isinstance(callback, str):

View file

@ -68,10 +68,7 @@ from litellm.router_utils.add_retry_fallback_headers import (
add_fallback_headers_to_response, add_fallback_headers_to_response,
add_retry_headers_to_response, add_retry_headers_to_response,
) )
from litellm.router_utils.batch_utils import ( from litellm.router_utils.batch_utils import _get_router_metadata_variable_name
_get_router_metadata_variable_name,
replace_model_in_jsonl,
)
from litellm.router_utils.client_initalization_utils import InitalizeCachedClient from litellm.router_utils.client_initalization_utils import InitalizeCachedClient
from litellm.router_utils.clientside_credential_handler import ( from litellm.router_utils.clientside_credential_handler import (
get_dynamic_litellm_params, get_dynamic_litellm_params,
@ -105,7 +102,12 @@ from litellm.router_utils.router_callbacks.track_deployment_metrics import (
increment_deployment_successes_for_current_minute, increment_deployment_successes_for_current_minute,
) )
from litellm.scheduler import FlowItem, Scheduler from litellm.scheduler import FlowItem, Scheduler
from litellm.types.llms.openai import AllMessageValues, Batch, FileObject, FileTypes from litellm.types.llms.openai import (
AllMessageValues,
Batch,
FileTypes,
OpenAIFileObject,
)
from litellm.types.router import ( from litellm.types.router import (
CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS, CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
VALID_LITELLM_ENVIRONMENTS, VALID_LITELLM_ENVIRONMENTS,
@ -2703,7 +2705,7 @@ class Router:
self, self,
model: str, model: str,
**kwargs, **kwargs,
) -> FileObject: ) -> OpenAIFileObject:
try: try:
kwargs["model"] = model kwargs["model"] = model
kwargs["original_function"] = self._acreate_file kwargs["original_function"] = self._acreate_file
@ -2727,7 +2729,7 @@ class Router:
self, self,
model: str, model: str,
**kwargs, **kwargs,
) -> FileObject: ) -> OpenAIFileObject:
try: try:
verbose_router_logger.debug( verbose_router_logger.debug(
f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}" f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
@ -2754,9 +2756,9 @@ class Router:
stripped_model, custom_llm_provider, _, _ = get_llm_provider( stripped_model, custom_llm_provider, _, _ = get_llm_provider(
model=data["model"] model=data["model"]
) )
kwargs["file"] = replace_model_in_jsonl( # kwargs["file"] = replace_model_in_jsonl(
file_content=kwargs["file"], new_model_name=stripped_model # file_content=kwargs["file"], new_model_name=stripped_model
) # )
response = litellm.acreate_file( response = litellm.acreate_file(
**{ **{
@ -2796,6 +2798,7 @@ class Router:
verbose_router_logger.info( verbose_router_logger.info(
f"litellm.acreate_file(model={model_name})\033[32m 200 OK\033[0m" f"litellm.acreate_file(model={model_name})\033[32m 200 OK\033[0m"
) )
return response # type: ignore return response # type: ignore
except Exception as e: except Exception as e:
verbose_router_logger.exception( verbose_router_logger.exception(

View file

@ -0,0 +1,14 @@
import hashlib
import json
from litellm.types.router import CredentialLiteLLMParams
def get_litellm_params_sensitive_credential_hash(litellm_params: dict) -> str:
"""
Hash of the credential params, used for mapping the file id to the right model
"""
sensitive_params = CredentialLiteLLMParams(**litellm_params)
return hashlib.sha256(
json.dumps(sensitive_params.model_dump()).encode()
).hexdigest()
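
Rough illustration of the new credential-hash helper (credential values are dummies; the module path isn't shown in this view, so the snippet just calls the function defined above):

```python
# Deployments that share the same sensitive credentials should hash to the
# same value, which is what lets a stored file id be mapped back to the
# deployment that created it (assumption based on the docstring above).
params = {
    "api_key": "sk-dummy",
    "api_base": "https://example.openai.azure.com",
    "api_version": "2024-06-01",
}
print(get_litellm_params_sensitive_credential_hash(params))  # deterministic sha256 hex digest
```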

View file

@ -234,7 +234,18 @@ class Thread(BaseModel):
"""The object type, which is always `thread`.""" """The object type, which is always `thread`."""
OpenAICreateFileRequestOptionalParams = Literal["purpose",] OpenAICreateFileRequestOptionalParams = Literal["purpose"]
OpenAIFilesPurpose = Literal[
"assistants",
"assistants_output",
"batch",
"batch_output",
"fine-tune",
"fine-tune-results",
"vision",
"user_data",
]
class OpenAIFileObject(BaseModel): class OpenAIFileObject(BaseModel):
@ -253,16 +264,7 @@ class OpenAIFileObject(BaseModel):
object: Literal["file"] object: Literal["file"]
"""The object type, which is always `file`.""" """The object type, which is always `file`."""
purpose: Literal[ purpose: OpenAIFilesPurpose
"assistants",
"assistants_output",
"batch",
"batch_output",
"fine-tune",
"fine-tune-results",
"vision",
"user_data",
]
"""The intended purpose of the file. """The intended purpose of the file.
Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`, Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`,
@ -286,6 +288,8 @@ class OpenAIFileObject(BaseModel):
`error` field on `fine_tuning.job`. `error` field on `fine_tuning.job`.
""" """
_hidden_params: dict = {}
# OpenAI Files Types # OpenAI Files Types
class CreateFileRequest(TypedDict, total=False): class CreateFileRequest(TypedDict, total=False):

View file

@ -18,6 +18,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from ..exceptions import RateLimitError from ..exceptions import RateLimitError
from .completion import CompletionRequest from .completion import CompletionRequest
from .embedding import EmbeddingRequest from .embedding import EmbeddingRequest
from .llms.openai import OpenAIFileObject
from .llms.vertex_ai import VERTEX_CREDENTIALS_TYPES from .llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
from .utils import ModelResponse, ProviderSpecificModelInfo from .utils import ModelResponse, ProviderSpecificModelInfo
@ -703,3 +704,12 @@ class GenericBudgetWindowDetails(BaseModel):
OptionalPreCallChecks = List[Literal["prompt_caching", "router_budget_limiting"]] OptionalPreCallChecks = List[Literal["prompt_caching", "router_budget_limiting"]]
class LiteLLM_RouterFileObject(TypedDict, total=False):
"""
Tracking the litellm params hash, used for mapping the file id to the right model
"""
litellm_params_sensitive_credential_hash: str
file_object: OpenAIFileObject

View file

@ -1886,6 +1886,7 @@ all_litellm_params = [
"logger_fn", "logger_fn",
"verbose", "verbose",
"custom_llm_provider", "custom_llm_provider",
"model_file_id_mapping",
"litellm_logging_obj", "litellm_logging_obj",
"litellm_call_id", "litellm_call_id",
"use_client", "use_client",

View file

@ -17,6 +17,7 @@ import litellm
from litellm.litellm_core_utils.litellm_logging import Logging from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.types.utils import ( from litellm.types.utils import (
CompletionTokensDetailsWrapper,
Delta, Delta,
ModelResponseStream, ModelResponseStream,
PromptTokensDetailsWrapper, PromptTokensDetailsWrapper,
@ -430,11 +431,18 @@ async def test_streaming_handler_with_usage(
completion_tokens=392, completion_tokens=392,
prompt_tokens=1799, prompt_tokens=1799,
total_tokens=2191, total_tokens=2191,
completion_tokens_details=None, completion_tokens_details=CompletionTokensDetailsWrapper( # <-- This has a value
accepted_prediction_tokens=None,
audio_tokens=None,
reasoning_tokens=0,
rejected_prediction_tokens=None,
text_tokens=None,
),
prompt_tokens_details=PromptTokensDetailsWrapper( prompt_tokens_details=PromptTokensDetailsWrapper(
audio_tokens=None, cached_tokens=1796, text_tokens=None, image_tokens=None audio_tokens=None, cached_tokens=1796, text_tokens=None, image_tokens=None
), ),
) )
final_chunk = ModelResponseStream( final_chunk = ModelResponseStream(
id="chatcmpl-87291500-d8c5-428e-b187-36fe5a4c97ab", id="chatcmpl-87291500-d8c5-428e-b187-36fe5a4c97ab",
created=1742056047, created=1742056047,
@ -510,7 +518,13 @@ async def test_streaming_with_usage_and_logging(sync_mode: bool):
completion_tokens=392, completion_tokens=392,
prompt_tokens=1799, prompt_tokens=1799,
total_tokens=2191, total_tokens=2191,
completion_tokens_details=None, completion_tokens_details=CompletionTokensDetailsWrapper(
accepted_prediction_tokens=None,
audio_tokens=None,
reasoning_tokens=0,
rejected_prediction_tokens=None,
text_tokens=None,
),
prompt_tokens_details=PromptTokensDetailsWrapper( prompt_tokens_details=PromptTokensDetailsWrapper(
audio_tokens=None, audio_tokens=None,
cached_tokens=1796, cached_tokens=1796,

View file

@ -0,0 +1,306 @@
import json
import os
import sys
from unittest.mock import ANY
import pytest
import respx
from fastapi.testclient import TestClient
from pytest_mock import MockerFixture
sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the parent directory to the system path
import litellm
from litellm import Router
from litellm.proxy._types import LiteLLM_UserTableFiltered, UserAPIKeyAuth
from litellm.proxy.hooks import get_proxy_hook
from litellm.proxy.management_endpoints.internal_user_endpoints import ui_view_users
from litellm.proxy.proxy_server import app
client = TestClient(app)
from litellm.caching.caching import DualCache
from litellm.proxy.proxy_server import hash_token
from litellm.proxy.utils import ProxyLogging
@pytest.fixture
def llm_router() -> Router:
llm_router = Router(
model_list=[
{
"model_name": "azure-gpt-3-5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": "azure_api_key",
"api_base": "azure_api_base",
"api_version": "azure_api_version",
},
"model_info": {
"id": "azure-gpt-3-5-turbo-id",
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "openai/gpt-3.5-turbo",
"api_key": "openai_api_key",
},
"model_info": {
"id": "gpt-3.5-turbo-id",
},
},
{
"model_name": "gemini-2.0-flash",
"litellm_params": {
"model": "gemini/gemini-2.0-flash",
},
"model_info": {
"id": "gemini-2.0-flash-id",
},
},
]
)
return llm_router
def setup_proxy_logging_object(monkeypatch, llm_router: Router) -> ProxyLogging:
proxy_logging_object = ProxyLogging(
user_api_key_cache=DualCache(default_in_memory_ttl=1)
)
proxy_logging_object._add_proxy_hooks(llm_router)
monkeypatch.setattr(
"litellm.proxy.proxy_server.proxy_logging_obj", proxy_logging_object
)
return proxy_logging_object
def test_invalid_purpose(mocker: MockerFixture, monkeypatch, llm_router: Router):
"""
Asserts 'create_file' is called with the correct arguments
"""
# Create a simple test file content
test_file_content = b"test audio content"
test_file = ("test.wav", test_file_content, "audio/wav")
response = client.post(
"/v1/files",
files={"file": test_file},
data={
"purpose": "my-bad-purpose",
"target_model_names": ["azure-gpt-3-5-turbo", "gpt-3.5-turbo"],
},
headers={"Authorization": "Bearer test-key"},
)
assert response.status_code == 400
print(f"response: {response.json()}")
assert "Invalid purpose: my-bad-purpose" in response.json()["error"]["message"]
def test_mock_create_audio_file(mocker: MockerFixture, monkeypatch, llm_router: Router):
"""
Asserts 'create_file' is called with the correct arguments
"""
from litellm import Router
mock_create_file = mocker.patch("litellm.files.main.create_file")
monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", llm_router)
# Create a simple test file content
test_file_content = b"test audio content"
test_file = ("test.wav", test_file_content, "audio/wav")
response = client.post(
"/v1/files",
files={"file": test_file},
data={
"purpose": "user_data",
"target_model_names": "azure-gpt-3-5-turbo, gpt-3.5-turbo",
},
headers={"Authorization": "Bearer test-key"},
)
print(f"response: {response.text}")
assert response.status_code == 200
# Get all calls made to create_file
calls = mock_create_file.call_args_list
# Check for Azure call
azure_call_found = False
for call in calls:
kwargs = call.kwargs
if (
kwargs.get("custom_llm_provider") == "azure"
and kwargs.get("model") == "azure/chatgpt-v-2"
and kwargs.get("api_key") == "azure_api_key"
):
azure_call_found = True
break
assert (
azure_call_found
), f"Azure call not found with expected parameters. Calls: {calls}"
# Check for OpenAI call
openai_call_found = False
for call in calls:
kwargs = call.kwargs
if (
kwargs.get("custom_llm_provider") == "openai"
and kwargs.get("model") == "openai/gpt-3.5-turbo"
and kwargs.get("api_key") == "openai_api_key"
):
openai_call_found = True
break
assert openai_call_found, "OpenAI call not found with expected parameters"
@pytest.mark.skip(reason="mock respx fails on ci/cd - unclear why")
def test_create_file_and_call_chat_completion_e2e(
mocker: MockerFixture, monkeypatch, llm_router: Router
):
"""
1. Create a file
2. Call a chat completion with the file
3. Assert the file is used in the chat completion
"""
# Create and enable respx mock instance
mock = respx.mock()
mock.start()
try:
from litellm.types.llms.openai import OpenAIFileObject
monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", llm_router)
proxy_logging_object = setup_proxy_logging_object(monkeypatch, llm_router)
# Create a simple test file content
test_file_content = b"test audio content"
test_file = ("test.wav", test_file_content, "audio/wav")
# Mock the file creation response
mock_file_response = OpenAIFileObject(
id="test-file-id",
object="file",
bytes=123,
created_at=1234567890,
filename="test.wav",
purpose="user_data",
status="uploaded",
)
mock_file_response._hidden_params = {"model_id": "gemini-2.0-flash-id"}
mocker.patch.object(llm_router, "acreate_file", return_value=mock_file_response)
# Mock the Gemini API call using respx
mock_gemini_response = {
"candidates": [
{"content": {"parts": [{"text": "This is a test audio file"}]}}
]
}
# Mock the Gemini API endpoint with a more flexible pattern
gemini_route = mock.post(
url__regex=r".*generativelanguage\.googleapis\.com.*"
).mock(
return_value=respx.MockResponse(status_code=200, json=mock_gemini_response),
)
# Print updated mock setup
print("\nAfter Adding Gemini Route:")
print("==========================")
print(f"Number of mocked routes: {len(mock.routes)}")
for route in mock.routes:
print(f"Mocked Route: {route}")
print(f"Pattern: {route.pattern}")
## CREATE FILE
file = client.post(
"/v1/files",
files={"file": test_file},
data={
"purpose": "user_data",
"target_model_names": "gemini-2.0-flash, gpt-3.5-turbo",
},
headers={"Authorization": "Bearer test-key"},
)
print("\nAfter File Creation:")
print("====================")
print(f"File creation status: {file.status_code}")
print(f"Recorded calls so far: {len(mock.calls)}")
for call in mock.calls:
print(f"Call made to: {call.request.method} {call.request.url}")
assert file.status_code == 200
assert file.json()["id"] != "test-file-id" # unified file id used
## USE FILE IN CHAT COMPLETION
try:
completion = client.post(
"/v1/chat/completions",
json={
"model": "gemini-2.0-flash",
"modalities": ["text", "audio"],
"audio": {"voice": "alloy", "format": "wav"},
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What is in this recording?"},
{
"type": "file",
"file": {
"file_id": file.json()["id"],
"filename": "my-test-name",
"format": "audio/wav",
},
},
],
},
],
"drop_params": True,
},
headers={"Authorization": "Bearer test-key"},
)
except Exception as e:
print(f"error: {e}")
print("\nError occurred during chat completion:")
print("=====================================")
print("\nFinal Mock State:")
print("=================")
print(f"Total mocked routes: {len(mock.routes)}")
for route in mock.routes:
print(f"\nMocked Route: {route}")
print(f" Called: {route.called}")
print("\nActual Requests Made:")
print("=====================")
print(f"Total calls recorded: {len(mock.calls)}")
for idx, call in enumerate(mock.calls):
print(f"\nCall {idx + 1}:")
print(f" Method: {call.request.method}")
print(f" URL: {call.request.url}")
print(f" Headers: {dict(call.request.headers)}")
try:
print(f" Body: {call.request.content.decode()}")
except:
print(" Body: <could not decode>")
# Verify Gemini API was called
assert gemini_route.called, "Gemini API was not called"
# Print the call details
print("\nGemini API Call Details:")
print(f"URL: {gemini_route.calls.last.request.url}")
print(f"Method: {gemini_route.calls.last.request.method}")
print(f"Headers: {dict(gemini_route.calls.last.request.headers)}")
print(f"Content: {gemini_route.calls.last.request.content.decode()}")
print(f"Response: {gemini_route.calls.last.response.content.decode()}")
assert "test-file-id" in gemini_route.calls.last.request.content.decode()
finally:
# Stop the mock
mock.stop()

View file

@ -39,7 +39,6 @@ async def test_initialize_scheduled_jobs_credentials(monkeypatch):
with patch("litellm.proxy.proxy_server.proxy_config", mock_proxy_config), patch( with patch("litellm.proxy.proxy_server.proxy_config", mock_proxy_config), patch(
"litellm.proxy.proxy_server.store_model_in_db", False "litellm.proxy.proxy_server.store_model_in_db", False
): # set store_model_in_db to False ): # set store_model_in_db to False
# Test when store_model_in_db is False # Test when store_model_in_db is False
await ProxyStartupEvent.initialize_scheduled_background_jobs( await ProxyStartupEvent.initialize_scheduled_background_jobs(
general_settings={}, general_settings={},
@ -57,7 +56,6 @@ async def test_initialize_scheduled_jobs_credentials(monkeypatch):
with patch("litellm.proxy.proxy_server.proxy_config", mock_proxy_config), patch( with patch("litellm.proxy.proxy_server.proxy_config", mock_proxy_config), patch(
"litellm.proxy.proxy_server.store_model_in_db", True "litellm.proxy.proxy_server.store_model_in_db", True
), patch("litellm.proxy.proxy_server.get_secret_bool", return_value=True): ), patch("litellm.proxy.proxy_server.get_secret_bool", return_value=True):
await ProxyStartupEvent.initialize_scheduled_background_jobs( await ProxyStartupEvent.initialize_scheduled_background_jobs(
general_settings={}, general_settings={},
prisma_client=mock_prisma_client, prisma_client=mock_prisma_client,

View file

@ -1116,3 +1116,6 @@ def test_anthropic_thinking_in_assistant_message(model):
response = litellm.completion(**params) response = litellm.completion(**params)
assert response is not None assert response is not None

View file

@ -34,6 +34,7 @@ export const prepareModelAddRequest = async (
} }
// Create a deployment for each mapping // Create a deployment for each mapping
const deployments = [];
for (const mapping of modelMappings) { for (const mapping of modelMappings) {
const litellmParamsObj: Record<string, any> = {}; const litellmParamsObj: Record<string, any> = {};
const modelInfoObj: Record<string, any> = {}; const modelInfoObj: Record<string, any> = {};
@ -142,8 +143,10 @@ export const prepareModelAddRequest = async (
} }
} }
return { litellmParamsObj, modelInfoObj, modelName }; deployments.push({ litellmParamsObj, modelInfoObj, modelName });
} }
return deployments;
} catch (error) { } catch (error) {
message.error("Failed to create model: " + error, 10); message.error("Failed to create model: " + error, 10);
} }
@ -156,22 +159,25 @@ export const handleAddModelSubmit = async (
callback?: () => void, callback?: () => void,
) => { ) => {
try { try {
const result = await prepareModelAddRequest(values, accessToken, form); const deployments = await prepareModelAddRequest(values, accessToken, form);
if (!result) { if (!deployments || deployments.length === 0) {
return; // Exit if preparation failed return; // Exit if preparation failed or no deployments
} }
const { litellmParamsObj, modelInfoObj, modelName } = result; // Create each deployment
for (const deployment of deployments) {
const { litellmParamsObj, modelInfoObj, modelName } = deployment;
const new_model: Model = { const new_model: Model = {
model_name: modelName, model_name: modelName,
litellm_params: litellmParamsObj, litellm_params: litellmParamsObj,
model_info: modelInfoObj, model_info: modelInfoObj,
}; };
const response: any = await modelCreateCall(accessToken, new_model); const response: any = await modelCreateCall(accessToken, new_model);
console.log(`response for model create call: ${response["data"]}`); console.log(`response for model create call: ${response["data"]}`);
}
callback && callback(); callback && callback();
form.resetFields(); form.resetFields();

View file

@ -55,7 +55,7 @@ const ModelConnectionTest: React.FC<ModelConnectionTestProps> = ({
console.log("Result from prepareModelAddRequest:", result); console.log("Result from prepareModelAddRequest:", result);
const { litellmParamsObj, modelInfoObj, modelName: returnedModelName } = result; const { litellmParamsObj, modelInfoObj, modelName: returnedModelName } = result[0];
const response = await testConnectionRequest(accessToken, litellmParamsObj, modelInfoObj?.mode); const response = await testConnectionRequest(accessToken, litellmParamsObj, modelInfoObj?.mode);
if (response.status === "success") { if (response.status === "success") {

View file

@ -13,7 +13,7 @@ import {
TabPanel, TabPanels, DonutChart, TabPanel, TabPanels, DonutChart,
Table, TableHead, TableRow, Table, TableHead, TableRow,
TableHeaderCell, TableBody, TableCell, TableHeaderCell, TableBody, TableCell,
Subtitle Subtitle, DateRangePicker, DateRangePickerValue
} from "@tremor/react"; } from "@tremor/react";
import { AreaChart } from "@tremor/react"; import { AreaChart } from "@tremor/react";
@ -41,6 +41,12 @@ const NewUsagePage: React.FC<NewUsagePageProps> = ({
metadata: any; metadata: any;
}>({ results: [], metadata: {} }); }>({ results: [], metadata: {} });
// Add date range state
const [dateValue, setDateValue] = useState<DateRangePickerValue>({
from: new Date(Date.now() - 28 * 24 * 60 * 60 * 1000),
to: new Date(),
});
// Derived states from userSpendData // Derived states from userSpendData
const totalSpend = userSpendData.metadata?.total_spend || 0; const totalSpend = userSpendData.metadata?.total_spend || 0;
@ -168,22 +174,34 @@ const NewUsagePage: React.FC<NewUsagePageProps> = ({
}; };
const fetchUserSpendData = async () => { const fetchUserSpendData = async () => {
if (!accessToken) return; if (!accessToken || !dateValue.from || !dateValue.to) return;
const startTime = new Date(Date.now() - 28 * 24 * 60 * 60 * 1000); const startTime = dateValue.from;
const endTime = new Date(); const endTime = dateValue.to;
const data = await userDailyActivityCall(accessToken, startTime, endTime); const data = await userDailyActivityCall(accessToken, startTime, endTime);
setUserSpendData(data); setUserSpendData(data);
}; };
useEffect(() => { useEffect(() => {
fetchUserSpendData(); fetchUserSpendData();
}, [accessToken]); }, [accessToken, dateValue]);
const modelMetrics = processActivityData(userSpendData); const modelMetrics = processActivityData(userSpendData);
return ( return (
<div style={{ width: "100%" }} className="p-8"> <div style={{ width: "100%" }} className="p-8">
<Text>Experimental Usage page, using new `/user/daily/activity` endpoint.</Text> <Text>Experimental Usage page, using new `/user/daily/activity` endpoint.</Text>
<Grid numItems={2} className="gap-2 w-full mb-4">
<Col>
<Text>Select Time Range</Text>
<DateRangePicker
enableSelect={true}
value={dateValue}
onValueChange={(value) => {
setDateValue(value);
}}
/>
</Col>
</Grid>
<TabGroup> <TabGroup>
<TabList variant="solid" className="mt-1"> <TabList variant="solid" className="mt-1">
<Tab>Cost</Tab> <Tab>Cost</Tab>

View file

@ -175,6 +175,14 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
try { try {
if (!accessToken) return; if (!accessToken) return;
let parsedMetadata = {};
try {
parsedMetadata = values.metadata ? JSON.parse(values.metadata) : {};
} catch (e) {
message.error("Invalid JSON in metadata field");
return;
}
const updateData = { const updateData = {
team_id: teamId, team_id: teamId,
team_alias: values.team_alias, team_alias: values.team_alias,
@ -184,7 +192,7 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
max_budget: values.max_budget, max_budget: values.max_budget,
budget_duration: values.budget_duration, budget_duration: values.budget_duration,
metadata: { metadata: {
...values.metadata, ...parsedMetadata,
guardrails: values.guardrails || [] guardrails: values.guardrails || []
} }
}; };