Compare commits

main..v1.67.3.dev1

No commits in common. "main" and "v1.67.3.dev1" have entirely different histories.

17 changed files with 42 additions and 297 deletions

View file

@@ -108,13 +108,7 @@ class ProxyBaseLLMRequestProcessing:
         user_api_key_dict: UserAPIKeyAuth,
         proxy_logging_obj: ProxyLogging,
         proxy_config: ProxyConfig,
-        route_type: Literal[
-            "acompletion",
-            "aresponses",
-            "_arealtime",
-            "aget_responses",
-            "adelete_responses",
-        ],
+        route_type: Literal["acompletion", "aresponses", "_arealtime"],
         version: Optional[str] = None,
         user_model: Optional[str] = None,
         user_temperature: Optional[float] = None,
@@ -184,13 +178,7 @@ class ProxyBaseLLMRequestProcessing:
         request: Request,
         fastapi_response: Response,
         user_api_key_dict: UserAPIKeyAuth,
-        route_type: Literal[
-            "acompletion",
-            "aresponses",
-            "_arealtime",
-            "aget_responses",
-            "adelete_responses",
-        ],
+        route_type: Literal["acompletion", "aresponses", "_arealtime"],
         proxy_logging_obj: ProxyLogging,
         general_settings: dict,
         proxy_config: ProxyConfig,

View file

@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict, List, Optional
+from typing import Dict, List, Optional

 import orjson
 from fastapi import Request, UploadFile, status
@@ -147,10 +147,10 @@ def check_file_size_under_limit(
     if llm_router is not None and request_data["model"] in router_model_names:
         try:
-            deployment: Optional[Deployment] = (
-                llm_router.get_deployment_by_model_group_name(
-                    model_group_name=request_data["model"]
-                )
-            )
+            deployment: Optional[
+                Deployment
+            ] = llm_router.get_deployment_by_model_group_name(
+                model_group_name=request_data["model"]
+            )
             if (
                 deployment
@@ -185,23 +185,3 @@ def check_file_size_under_limit(
             )

     return True
-
-
-async def get_form_data(request: Request) -> Dict[str, Any]:
-    """
-    Read form data from request
-
-    Handles when OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
-    """
-    form = await request.form()
-    form_data = dict(form)
-    parsed_form_data: dict[str, Any] = {}
-    for key, value in form_data.items():
-        # OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
-        if key.endswith("[]"):
-            clean_key = key[:-2]
-            parsed_form_data.setdefault(clean_key, []).append(value)
-        else:
-            parsed_form_data[key] = value
-    return parsed_form_data
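The `get_form_data` helper removed above normalizes OpenAI-SDK-style array keys such as `timestamp_granularities[]` into plain lists before the proxy forwards the request. A minimal standalone sketch of that normalization, assuming a plain dict input (the function name here is illustrative; the real helper is async and reads Starlette's `request.form()` multidict):

```python
from typing import Any, Dict


def normalize_form_keys(form_data: Dict[str, Any]) -> Dict[str, Any]:
    """Collect keys ending in '[]' into lists; pass everything else through."""
    parsed: Dict[str, Any] = {}
    for key, value in form_data.items():
        if key.endswith("[]"):
            parsed.setdefault(key[:-2], []).append(value)
        else:
            parsed[key] = value
    return parsed


print(normalize_form_keys({"model": "gpt-4o-transcribe", "timestamp_granularities[]": "word"}))
# -> {'model': 'gpt-4o-transcribe', 'timestamp_granularities': ['word']}
```

As the deleted unit test further down notes, converting the form to a plain dict keeps only the last value for a repeated `[]` key; a true multidict would preserve both.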

View file

@@ -1,8 +1,16 @@
 model_list:
-  - model_name: openai/*
+  - model_name: azure-computer-use-preview
     litellm_params:
-      model: openai/*
-      api_key: os.environ/OPENAI_API_KEY
+      model: azure/computer-use-preview
+      api_key: mock-api-key
+      api_version: mock-api-version
+      api_base: https://mock-endpoint.openai.azure.com
+  - model_name: azure-computer-use-preview
+    litellm_params:
+      model: azure/computer-use-preview-2
+      api_key: mock-api-key-2
+      api_version: mock-api-version-2
+      api_base: https://mock-endpoint-2.openai.azure.com

 router_settings:
   optional_pre_call_checks: ["responses_api_deployment_check"]

View file

@@ -179,7 +179,6 @@ from litellm.proxy.common_utils.html_forms.ui_login import html_form
 from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     check_file_size_under_limit,
-    get_form_data,
 )
 from litellm.proxy.common_utils.load_config_utils import (
     get_config_file_contents_from_gcs,
@@ -805,9 +804,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
     dual_cache=user_api_key_cache
 )
 litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
-redis_usage_cache: Optional[RedisCache] = (
-    None  # redis cache used for tracking spend, tpm/rpm limits
-)
+redis_usage_cache: Optional[
+    RedisCache
+] = None  # redis cache used for tracking spend, tpm/rpm limits
 user_custom_auth = None
 user_custom_key_generate = None
 user_custom_sso = None
@@ -1133,9 +1132,9 @@ async def update_cache( # noqa: PLR0915
         _id = "team_id:{}".format(team_id)
         try:
             # Fetch the existing cost for the given user
-            existing_spend_obj: Optional[LiteLLM_TeamTable] = (
-                await user_api_key_cache.async_get_cache(key=_id)
-            )
+            existing_spend_obj: Optional[
+                LiteLLM_TeamTable
+            ] = await user_api_key_cache.async_get_cache(key=_id)
             if existing_spend_obj is None:
                 # do nothing if team not in api key cache
                 return
@@ -2807,9 +2806,9 @@ async def initialize( # noqa: PLR0915
         user_api_base = api_base
         dynamic_config[user_model]["api_base"] = api_base
     if api_version:
-        os.environ["AZURE_API_VERSION"] = (
-            api_version  # set this for azure - litellm can read this from the env
-        )
+        os.environ[
+            "AZURE_API_VERSION"
+        ] = api_version  # set this for azure - litellm can read this from the env
     if max_tokens:  # model-specific param
         dynamic_config[user_model]["max_tokens"] = max_tokens
     if temperature:  # model-specific param
@@ -4121,7 +4120,7 @@ async def audio_transcriptions(
     data: Dict = {}
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
-        form_data = await get_form_data(request)
+        form_data = await request.form()
         data = {key: value for key, value in form_data.items() if key != "file"}

         # Include original request and headers in the data
@@ -7759,9 +7758,9 @@ async def get_config_list(
                     hasattr(sub_field_info, "description")
                     and sub_field_info.description is not None
                 ):
-                    nested_fields[idx].field_description = (
-                        sub_field_info.description
-                    )
+                    nested_fields[
+                        idx
+                    ].field_description = sub_field_info.description
                 idx += 1

             _stored_in_db = None

View file

@@ -106,50 +106,8 @@ async def get_response(
         -H "Authorization: Bearer sk-1234"
     ```
     """
-    from litellm.proxy.proxy_server import (
-        _read_request_body,
-        general_settings,
-        llm_router,
-        proxy_config,
-        proxy_logging_obj,
-        select_data_generator,
-        user_api_base,
-        user_max_tokens,
-        user_model,
-        user_request_timeout,
-        user_temperature,
-        version,
-    )
-    data = await _read_request_body(request=request)
-    data["response_id"] = response_id
-    processor = ProxyBaseLLMRequestProcessing(data=data)
-    try:
-        return await processor.base_process_llm_request(
-            request=request,
-            fastapi_response=fastapi_response,
-            user_api_key_dict=user_api_key_dict,
-            route_type="aget_responses",
-            proxy_logging_obj=proxy_logging_obj,
-            llm_router=llm_router,
-            general_settings=general_settings,
-            proxy_config=proxy_config,
-            select_data_generator=select_data_generator,
-            model=None,
-            user_model=user_model,
-            user_temperature=user_temperature,
-            user_request_timeout=user_request_timeout,
-            user_max_tokens=user_max_tokens,
-            user_api_base=user_api_base,
-            version=version,
-        )
-    except Exception as e:
-        raise await processor._handle_llm_api_exception(
-            e=e,
-            user_api_key_dict=user_api_key_dict,
-            proxy_logging_obj=proxy_logging_obj,
-            version=version,
-        )
+    # TODO: Implement response retrieval logic
+    pass


 @router.delete(
@@ -178,50 +136,8 @@ async def delete_response(
         -H "Authorization: Bearer sk-1234"
     ```
     """
-    from litellm.proxy.proxy_server import (
-        _read_request_body,
-        general_settings,
-        llm_router,
-        proxy_config,
-        proxy_logging_obj,
-        select_data_generator,
-        user_api_base,
-        user_max_tokens,
-        user_model,
-        user_request_timeout,
-        user_temperature,
-        version,
-    )
-    data = await _read_request_body(request=request)
-    data["response_id"] = response_id
-    processor = ProxyBaseLLMRequestProcessing(data=data)
-    try:
-        return await processor.base_process_llm_request(
-            request=request,
-            fastapi_response=fastapi_response,
-            user_api_key_dict=user_api_key_dict,
-            route_type="adelete_responses",
-            proxy_logging_obj=proxy_logging_obj,
-            llm_router=llm_router,
-            general_settings=general_settings,
-            proxy_config=proxy_config,
-            select_data_generator=select_data_generator,
-            model=None,
-            user_model=user_model,
-            user_temperature=user_temperature,
-            user_request_timeout=user_request_timeout,
-            user_max_tokens=user_max_tokens,
-            user_api_base=user_api_base,
-            version=version,
-        )
-    except Exception as e:
-        raise await processor._handle_llm_api_exception(
-            e=e,
-            user_api_key_dict=user_api_key_dict,
-            proxy_logging_obj=proxy_logging_obj,
-            version=version,
-        )
+    # TODO: Implement response deletion logic
+    pass


 @router.get(
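For reference, the retrieve/delete routes gutted above are exercised end to end in the test file further down; the same flow can be driven with the OpenAI SDK pointed at the proxy. A sketch, assuming a locally running proxy at `http://localhost:4000` with the `sk-1234` key from the docstring curl examples and the `azure-computer-use-preview` model group from the config above:

```python
from openai import OpenAI

# Base URL and key are assumptions matching the curl examples in the docstrings.
client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

created = client.responses.create(
    model="azure-computer-use-preview", input="just respond with the word 'ping'"
)
fetched = client.responses.retrieve(created.id)  # GET /v1/responses/{response_id}
deleted = client.responses.delete(created.id)    # DELETE /v1/responses/{response_id}
print(fetched.id, deleted)
```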

View file

@@ -47,8 +47,6 @@ async def route_request(
         "amoderation",
         "arerank",
         "aresponses",
-        "aget_responses",
-        "adelete_responses",
         "_arealtime",  # private function for realtime API
     ],
 ):

View file

@@ -176,16 +176,6 @@ class ResponsesAPIRequestUtils:
             response_id=response_id,
         )

-    @staticmethod
-    def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
-        """Get the model_id from the response_id"""
-        if response_id is None:
-            return None
-        decoded_response_id = (
-            ResponsesAPIRequestUtils._decode_responses_api_response_id(response_id)
-        )
-        return decoded_response_id.get("model_id") or None
-

 class ResponseAPILoggingUtils:
     @staticmethod

View file

@@ -739,12 +739,6 @@ class Router:
             litellm.afile_content, call_type="afile_content"
         )
         self.responses = self.factory_function(litellm.responses, call_type="responses")
-        self.aget_responses = self.factory_function(
-            litellm.aget_responses, call_type="aget_responses"
-        )
-        self.adelete_responses = self.factory_function(
-            litellm.adelete_responses, call_type="adelete_responses"
-        )

     def validate_fallbacks(self, fallback_param: Optional[List]):
         """
@@ -3087,8 +3081,6 @@ class Router:
             "anthropic_messages",
             "aresponses",
             "responses",
-            "aget_responses",
-            "adelete_responses",
             "afile_delete",
             "afile_content",
         ] = "assistants",
@@ -3143,11 +3135,6 @@ class Router:
                     original_function=original_function,
                     **kwargs,
                 )
-            elif call_type in ("aget_responses", "adelete_responses"):
-                return await self._init_responses_api_endpoints(
-                    original_function=original_function,
-                    **kwargs,
-                )
             elif call_type in ("afile_delete", "afile_content"):
                 return await self._ageneric_api_call_with_fallbacks(
                     original_function=original_function,
@@ -3158,28 +3145,6 @@ class Router:
         return async_wrapper

-    async def _init_responses_api_endpoints(
-        self,
-        original_function: Callable,
-        **kwargs,
-    ):
-        """
-        Initialize the Responses API endpoints on the router.
-
-        GET, DELETE Responses API Requests encode the model_id in the response_id, this function decodes the response_id and sets the model to the model_id.
-        """
-        from litellm.responses.utils import ResponsesAPIRequestUtils
-
-        model_id = ResponsesAPIRequestUtils.get_model_id_from_response_id(
-            kwargs.get("response_id")
-        )
-        if model_id is not None:
-            kwargs["model"] = model_id
-        return await self._ageneric_api_call_with_fallbacks(
-            original_function=original_function,
-            **kwargs,
-        )
-

     async def _pass_through_assistants_endpoint_factory(
         self,
         original_function: Callable,
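The removed `_init_responses_api_endpoints` hook relies on the fact that the proxy encodes the deployment's `model_id` inside the `response_id` it returns, so GET/DELETE calls can be pinned back to the deployment that created the response. A rough standalone sketch of that decode-then-pin idea; the base64/JSON encoding below is an illustrative stand-in, not litellm's actual `_decode_responses_api_response_id` format:

```python
import base64
import json
from typing import Any, Dict, Optional


def encode_response_id(model_id: str, upstream_response_id: str) -> str:
    # Illustrative encoding only: pack the deployment id alongside the upstream id.
    payload = json.dumps({"model_id": model_id, "response_id": upstream_response_id})
    return base64.urlsafe_b64encode(payload.encode()).decode()


def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
    if response_id is None:
        return None
    try:
        decoded = json.loads(base64.urlsafe_b64decode(response_id.encode()))
    except Exception:
        return None
    return decoded.get("model_id") or None


def pin_to_deployment(**kwargs: Any) -> Dict[str, Any]:
    # Mirrors the removed router hook: if the response_id names a deployment,
    # override "model" so the call routes back to that same deployment.
    model_id = get_model_id_from_response_id(kwargs.get("response_id"))
    if model_id is not None:
        kwargs["model"] = model_id
    return kwargs


rid = encode_response_id("model-id-123", "resp_abc")
print(pin_to_deployment(response_id=rid, model="azure-computer-use-preview"))
```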

View file

@@ -7058,17 +7058,6 @@
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models",
         "supports_tool_choice": true
     },
-    "command-a-03-2025": {
-        "max_tokens": 8000,
-        "max_input_tokens": 256000,
-        "max_output_tokens": 8000,
-        "input_cost_per_token": 0.0000025,
-        "output_cost_per_token": 0.00001,
-        "litellm_provider": "cohere_chat",
-        "mode": "chat",
-        "supports_function_calling": true,
-        "supports_tool_choice": true
-    },
     "command-r": {
         "max_tokens": 4096,
         "max_input_tokens": 128000,

View file

@@ -18,7 +18,6 @@ from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     _safe_get_request_parsed_body,
     _safe_set_request_parsed_body,
-    get_form_data,
 )
@@ -148,53 +147,3 @@ async def test_circular_reference_handling():
     assert (
         "proxy_server_request" not in result2
     )  # This will pass, showing the cache pollution
-
-
-@pytest.mark.asyncio
-async def test_get_form_data():
-    """
-    Test that get_form_data correctly handles form data with array notation.
-    Tests audio transcription parameters as a specific example.
-    """
-    # Create a mock request with transcription form data
-    mock_request = MagicMock()
-
-    # Create mock form data with array notation for timestamp_granularities
-    mock_form_data = {
-        "file": "file_object",  # In a real request this would be an UploadFile
-        "model": "gpt-4o-transcribe",
-        "include[]": "logprobs",  # Array notation
-        "language": "en",
-        "prompt": "Transcribe this audio file",
-        "response_format": "json",
-        "stream": "false",
-        "temperature": "0.2",
-        "timestamp_granularities[]": "word",  # First array item
-        "timestamp_granularities[]": "segment",  # Second array item (would overwrite in dict, but handled by the function)
-    }
-
-    # Mock the form method to return the test data
-    mock_request.form = AsyncMock(return_value=mock_form_data)
-
-    # Call the function being tested
-    result = await get_form_data(mock_request)
-
-    # Verify regular form fields are preserved
-    assert result["file"] == "file_object"
-    assert result["model"] == "gpt-4o-transcribe"
-    assert result["language"] == "en"
-    assert result["prompt"] == "Transcribe this audio file"
-    assert result["response_format"] == "json"
-    assert result["stream"] == "false"
-    assert result["temperature"] == "0.2"
-
-    # Verify array fields are correctly parsed
-    assert "include" in result
-    assert isinstance(result["include"], list)
-    assert "logprobs" in result["include"]
-    assert "timestamp_granularities" in result
-    assert isinstance(result["timestamp_granularities"], list)
-
-    # Note: In a real MultiDict, both values would be present
-    # But in our mock dictionary the second value overwrites the first
-    assert "segment" in result["timestamp_granularities"]

View file

@@ -947,8 +947,6 @@ class BaseLLMChatTest(ABC):
                 second_response.choices[0].message.content is not None
                 or second_response.choices[0].message.tool_calls is not None
             )
-        except litellm.ServiceUnavailableError:
-            pytest.skip("Model is overloaded")
         except litellm.InternalServerError:
             pytest.skip("Model is overloaded")
         except litellm.RateLimitError:

View file

@@ -73,31 +73,15 @@ def validate_stream_chunk(chunk):
 def test_basic_response():
     client = get_test_client()
     response = client.responses.create(
-        model="gpt-4.0", input="just respond with the word 'ping'"
+        model="gpt-4o", input="just respond with the word 'ping'"
     )
     print("basic response=", response)

-    # get the response
-    response = client.responses.retrieve(response.id)
-    print("GET response=", response)
-
-    # delete the response
-    delete_response = client.responses.delete(response.id)
-    print("DELETE response=", delete_response)
-
-    # try getting the response again, we should not get it back
-    get_response = client.responses.retrieve(response.id)
-    print("GET response after delete=", get_response)
-
-    with pytest.raises(Exception):
-        get_response = client.responses.retrieve(response.id)
-

 def test_streaming_response():
     client = get_test_client()
     stream = client.responses.create(
-        model="gpt-4.0", input="just respond with the word 'ping'", stream=True
+        model="gpt-4o", input="just respond with the word 'ping'", stream=True
     )

     collected_chunks = []
@@ -120,5 +104,5 @@ def test_bad_request_bad_param_error():
     with pytest.raises(BadRequestError):
         # Trigger error with invalid model name
         client.responses.create(
-            model="gpt-4.0", input="This should fail", temperature=2000
+            model="gpt-4o", input="This should fail", temperature=2000
         )

View file

@@ -1157,14 +1157,3 @@ def test_cached_get_model_group_info(model_list):
     # Verify the cache info shows hits
     cache_info = router._cached_get_model_group_info.cache_info()
     assert cache_info.hits > 0  # Should have at least one cache hit
-
-
-def test_init_responses_api_endpoints(model_list):
-    """Test if the '_init_responses_api_endpoints' function is working correctly"""
-    from typing import Callable
-
-    router = Router(model_list=model_list)
-
-    assert router.aget_responses is not None
-    assert isinstance(router.aget_responses, Callable)
-    assert router.adelete_responses is not None
-    assert isinstance(router.adelete_responses, Callable)

View file

@@ -151,7 +151,7 @@ export default function CreateKeyPage() {
     if (redirectToLogin) {
       window.location.href = (proxyBaseUrl || "") + "/sso/key/generate"
     }
-  }, [redirectToLogin])
+  }, [token, authLoading])

   useEffect(() => {
     if (!token) {
@@ -223,7 +223,7 @@ export default function CreateKeyPage() {
     }
   }, [accessToken, userID, userRole]);

-  if (authLoading || redirectToLogin) {
+  if (authLoading) {
     return <LoadingScreen />
   }

View file

@@ -450,8 +450,6 @@ export function AllKeysTable({
         columns={columns.filter(col => col.id !== 'expander') as any}
         data={filteredKeys as any}
         isLoading={isLoading}
-        loadingMessage="🚅 Loading keys..."
-        noDataMessage="No keys found"
         getRowCanExpand={() => false}
         renderSubComponent={() => <></>}
       />

View file

@@ -26,8 +26,6 @@ function DataTableWrapper({
       isLoading={isLoading}
       renderSubComponent={renderSubComponent}
       getRowCanExpand={getRowCanExpand}
-      loadingMessage="🚅 Loading tools..."
-      noDataMessage="No tools found"
     />
   );
 }

View file

@@ -26,8 +26,6 @@ interface DataTableProps<TData, TValue> {
   expandedRequestId?: string | null;
   onRowExpand?: (requestId: string | null) => void;
   setSelectedKeyIdInfoView?: (keyId: string | null) => void;
-  loadingMessage?: string;
-  noDataMessage?: string;
 }

 export function DataTable<TData extends { request_id: string }, TValue>({
@@ -38,8 +36,6 @@ export function DataTable<TData extends { request_id: string }, TValue>({
   isLoading = false,
   expandedRequestId,
   onRowExpand,
-  loadingMessage = "🚅 Loading logs...",
-  noDataMessage = "No logs found",
 }: DataTableProps<TData, TValue>) {
   const table = useReactTable({
     data,
@@ -118,7 +114,7 @@ export function DataTable<TData extends { request_id: string }, TValue>({
             <TableRow>
               <TableCell colSpan={columns.length} className="h-8 text-center">
                 <div className="text-center text-gray-500">
-                  <p>{loadingMessage}</p>
+                  <p>🚅 Loading logs...</p>
                 </div>
               </TableCell>
             </TableRow>
@@ -151,7 +147,7 @@ export function DataTable<TData extends { request_id: string }, TValue>({
             : <TableRow>
               <TableCell colSpan={columns.length} className="h-8 text-center">
                 <div className="text-center text-gray-500">
-                  <p>{noDataMessage}</p>
+                  <p>No logs found</p>
                 </div>
               </TableCell>
             </TableRow>