Merge pull request #9183 from BerriAI/litellm_router_responses_api_2

[Feat] - Add Responses API on LiteLLM Proxy
Ishaan Jaff 2025-03-12 21:28:16 -07:00 committed by GitHub
commit 1d31e25816
44 changed files with 1165 additions and 401 deletions

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# [BETA] `/v1/messages`
# /v1/messages [BETA]
LiteLLM provides a BETA endpoint in the spec of Anthropic's `/v1/messages` endpoint.

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Assistants API
# /assistants
Covers Threads, Messages, Assistants.

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# [BETA] Batches API
# /batches
Covers Batches, Files

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Embeddings
# /embeddings
## Quick Start
```python

View file

@ -2,7 +2,7 @@
import TabItem from '@theme/TabItem';
import Tabs from '@theme/Tabs';
# Files API
# /files
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# [Beta] Fine-tuning API
# /fine_tuning
:::info

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Moderation
# /moderations
### Usage

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Realtime Endpoints
# /realtime
Use this to loadbalance across Azure + OpenAI.

View file

@ -1,4 +1,4 @@
# Rerank
# /rerank
:::tip

View file

@ -0,0 +1,117 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# /responses
LiteLLM provides a BETA endpoint that follows the spec of [OpenAI's `/responses` API](https://platform.openai.com/docs/api-reference/responses).
| Feature | Supported | Notes |
|---------|-----------|--------|
| Cost Tracking | ✅ | Works with all supported models |
| Logging | ✅ | Works across all integrations |
| End-user Tracking | ✅ | |
| Streaming | ✅ | |
| Fallbacks | ✅ | Works between supported models |
| Loadbalancing | ✅ | Works between supported models |
| Supported LiteLLM Versions | 1.63.8+ | |
| Supported LLM providers | `openai` | |
## Usage
### Create a model response
<Tabs>
<TabItem value="litellm-sdk" label="LiteLLM SDK">
#### Non-streaming
```python
import litellm

# Non-streaming response
response = litellm.responses(
    model="gpt-4o",
    input="Tell me a three sentence bedtime story about a unicorn.",
    max_output_tokens=100
)

print(response)
```
#### Streaming
```python
import litellm

# Streaming response
response = litellm.responses(
    model="gpt-4o",
    input="Tell me a three sentence bedtime story about a unicorn.",
    stream=True
)

for event in response:
    print(event)
```
</TabItem>
<TabItem value="proxy" label="OpenAI SDK with LiteLLM Proxy">
First, add this to your litellm proxy config.yaml:
```yaml
model_list:
  - model_name: gpt-4o
    litellm_params:
      model: openai/gpt-4o
      api_key: os.environ/OPENAI_API_KEY
```
Start your LiteLLM proxy:
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
Then use the OpenAI SDK pointed to your proxy:
#### Non-streaming
```python
from openai import OpenAI

# Initialize client with your proxy URL
client = OpenAI(
    base_url="http://localhost:4000",  # Your proxy URL
    api_key="your-api-key"             # Your proxy API key
)

# Non-streaming response
response = client.responses.create(
    model="gpt-4o",
    input="Tell me a three sentence bedtime story about a unicorn."
)

print(response)
```
#### Streaming
```python
from openai import OpenAI

# Initialize client with your proxy URL
client = OpenAI(
    base_url="http://localhost:4000",  # Your proxy URL
    api_key="your-api-key"             # Your proxy API key
)

# Streaming response
response = client.responses.create(
    model="gpt-4o",
    input="Tell me a three sentence bedtime story about a unicorn.",
    stream=True
)

for event in response:
    print(event)
```
</TabItem>
</Tabs>
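
#### Using the LiteLLM Router (load balancing / fallbacks)

The Fallbacks and Loadbalancing rows in the table above refer to routing `/responses` calls through the LiteLLM Router. A minimal sketch, based on the router support and tests added in this commit (the `gpt4o-special-alias` model alias is illustrative):

```python
import asyncio
import os

import litellm

# Illustrative model alias; the underlying deployment is plain OpenAI gpt-4o.
router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt4o-special-alias",
            "litellm_params": {
                "model": "gpt-4o",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        }
    ]
)

# Synchronous call (router.responses is registered by this commit)
response = router.responses(
    model="gpt4o-special-alias",
    input="Tell me a three sentence bedtime story about a unicorn.",
    max_output_tokens=100,
)
print(response)

# Asynchronous call (router.aresponses)
async def main():
    return await router.aresponses(
        model="gpt4o-special-alias",
        input="Tell me a three sentence bedtime story about a unicorn.",
        max_output_tokens=100,
    )

print(asyncio.run(main()))
```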

View file

@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Text Completion
# /completions
### Usage
<Tabs>

View file

@ -273,7 +273,7 @@ const sidebars = {
items: [
{
type: "category",
label: "Chat",
label: "/chat/completions",
link: {
type: "generated-index",
title: "Chat Completions",
@ -286,12 +286,13 @@ const sidebars = {
"completion/usage",
],
},
"response_api",
"text_completion",
"embedding/supported_embedding",
"anthropic_unified",
{
type: "category",
label: "Image",
label: "/images",
items: [
"image_generation",
"image_variations",
@ -299,7 +300,7 @@ const sidebars = {
},
{
type: "category",
label: "Audio",
label: "/audio",
"items": [
"audio_transcription",
"text_to_speech",

View file

@ -163,7 +163,7 @@ class AporiaGuardrail(CustomGuardrail):
pass
async def async_moderation_hook( ### 👈 KEY CHANGE ###
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
@ -173,6 +173,7 @@ class AporiaGuardrail(CustomGuardrail):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
from litellm.proxy.common_utils.callback_utils import (

View file

@ -94,6 +94,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
"""

View file

@ -107,6 +107,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
"""

View file

@ -126,6 +126,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
"""

View file

@ -31,7 +31,7 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
#### CALL HOOKS - proxy only ####
async def async_moderation_hook( ### 👈 KEY CHANGE ###
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
@ -41,6 +41,7 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
text = ""

View file

@ -239,6 +239,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
) -> Any:
pass

View file

@ -14,6 +14,7 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.utils import ProxyLogging
@ -89,7 +90,6 @@ async def anthropic_response( # noqa: PLR0915
"""
from litellm.proxy.proxy_server import (
general_settings,
get_custom_headers,
llm_router,
proxy_config,
proxy_logging_obj,
@ -205,7 +205,7 @@ async def anthropic_response( # noqa: PLR0915
verbose_proxy_logger.debug("final response: %s", response)
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

View file

@ -18,6 +18,7 @@ from litellm.batches.main import (
)
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.common_utils.openai_endpoint_utils import (
get_custom_llm_provider_from_request_body,
@ -69,7 +70,6 @@ async def create_batch(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
llm_router,
proxy_config,
proxy_logging_obj,
@ -137,7 +137,7 @@ async def create_batch(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -201,7 +201,6 @@ async def retrieve_batch(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
llm_router,
proxy_config,
proxy_logging_obj,
@ -266,7 +265,7 @@ async def retrieve_batch(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -326,11 +325,7 @@ async def list_batches(
```
"""
from litellm.proxy.proxy_server import (
get_custom_headers,
proxy_logging_obj,
version,
)
from litellm.proxy.proxy_server import proxy_logging_obj, version
verbose_proxy_logger.debug("GET /v1/batches after={} limit={}".format(after, limit))
try:
@ -352,7 +347,7 @@ async def list_batches(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -417,7 +412,6 @@ async def cancel_batch(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
proxy_config,
proxy_logging_obj,
version,
@ -463,7 +457,7 @@ async def cancel_batch(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

View file

@ -0,0 +1,356 @@
import asyncio
import json
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
import httpx
from fastapi import HTTPException, Request, status
from fastapi.responses import Response, StreamingResponse
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import ProxyException, UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
from litellm.proxy.common_utils.callback_utils import (
get_logging_caching_headers,
get_remaining_tokens_and_requests_from_request_data,
)
from litellm.proxy.route_llm_request import route_request
from litellm.proxy.utils import ProxyLogging
from litellm.router import Router
if TYPE_CHECKING:
from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
ProxyConfig = _ProxyConfig
else:
ProxyConfig = Any
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
class ProxyBaseLLMRequestProcessing:
def __init__(self, data: dict):
self.data = data
@staticmethod
def get_custom_headers(
*,
user_api_key_dict: UserAPIKeyAuth,
call_id: Optional[str] = None,
model_id: Optional[str] = None,
cache_key: Optional[str] = None,
api_base: Optional[str] = None,
version: Optional[str] = None,
model_region: Optional[str] = None,
response_cost: Optional[Union[float, str]] = None,
hidden_params: Optional[dict] = None,
fastest_response_batch_completion: Optional[bool] = None,
request_data: Optional[dict] = {},
timeout: Optional[Union[float, int, httpx.Timeout]] = None,
**kwargs,
) -> dict:
exclude_values = {"", None, "None"}
hidden_params = hidden_params or {}
headers = {
"x-litellm-call-id": call_id,
"x-litellm-model-id": model_id,
"x-litellm-cache-key": cache_key,
"x-litellm-model-api-base": api_base,
"x-litellm-version": version,
"x-litellm-model-region": model_region,
"x-litellm-response-cost": str(response_cost),
"x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
"x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
"x-litellm-key-max-budget": str(user_api_key_dict.max_budget),
"x-litellm-key-spend": str(user_api_key_dict.spend),
"x-litellm-response-duration-ms": str(
hidden_params.get("_response_ms", None)
),
"x-litellm-overhead-duration-ms": str(
hidden_params.get("litellm_overhead_time_ms", None)
),
"x-litellm-fastest_response_batch_completion": (
str(fastest_response_batch_completion)
if fastest_response_batch_completion is not None
else None
),
"x-litellm-timeout": str(timeout) if timeout is not None else None,
**{k: str(v) for k, v in kwargs.items()},
}
if request_data:
remaining_tokens_header = (
get_remaining_tokens_and_requests_from_request_data(request_data)
)
headers.update(remaining_tokens_header)
logging_caching_headers = get_logging_caching_headers(request_data)
if logging_caching_headers:
headers.update(logging_caching_headers)
try:
return {
key: str(value)
for key, value in headers.items()
if value not in exclude_values
}
except Exception as e:
verbose_proxy_logger.error(f"Error setting custom headers: {e}")
return {}
async def base_process_llm_request(
self,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth,
route_type: Literal["acompletion", "aresponses"],
proxy_logging_obj: ProxyLogging,
general_settings: dict,
proxy_config: ProxyConfig,
select_data_generator: Callable,
llm_router: Optional[Router] = None,
model: Optional[str] = None,
user_model: Optional[str] = None,
user_temperature: Optional[float] = None,
user_request_timeout: Optional[float] = None,
user_max_tokens: Optional[int] = None,
user_api_base: Optional[str] = None,
version: Optional[str] = None,
) -> Any:
"""
Common request processing logic for both chat completions and responses API endpoints
"""
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(json.dumps(self.data, indent=4)),
)
self.data = await add_litellm_data_to_request(
data=self.data,
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
self.data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or model # for azure deployments
or self.data.get("model", None) # default passed in http request
)
# override with user settings, these are params passed via cli
if user_temperature:
self.data["temperature"] = user_temperature
if user_request_timeout:
self.data["request_timeout"] = user_request_timeout
if user_max_tokens:
self.data["max_tokens"] = user_max_tokens
if user_api_base:
self.data["api_base"] = user_api_base
### MODEL ALIAS MAPPING ###
# check if model name in model alias map
# get the actual model name
if (
isinstance(self.data["model"], str)
and self.data["model"] in litellm.model_alias_map
):
self.data["model"] = litellm.model_alias_map[self.data["model"]]
### CALL HOOKS ### - modify/reject incoming data before calling the model
self.data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=self.data, call_type="completion"
)
## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse.
self.data["litellm_call_id"] = request.headers.get(
"x-litellm-call-id", str(uuid.uuid4())
)
logging_obj, self.data = litellm.utils.function_setup(
original_function=route_type,
rules_obj=litellm.utils.Rules(),
start_time=datetime.now(),
**self.data,
)
self.data["litellm_logging_obj"] = logging_obj
tasks = []
tasks.append(
proxy_logging_obj.during_call_hook(
data=self.data,
user_api_key_dict=user_api_key_dict,
call_type=ProxyBaseLLMRequestProcessing._get_pre_call_type(
route_type=route_type
),
)
)
### ROUTE THE REQUEST ###
# Do not change this - it should be a constant time fetch - ALWAYS
llm_call = await route_request(
data=self.data,
route_type=route_type,
llm_router=llm_router,
user_model=user_model,
)
tasks.append(llm_call)
# wait for call to end
llm_responses = asyncio.gather(
*tasks
) # run the moderation check in parallel to the actual llm api call
responses = await llm_responses
response = responses[1]
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
fastest_response_batch_completion = hidden_params.get(
"fastest_response_batch_completion", None
)
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
# Post Call Processing
if llm_router is not None:
self.data["deployment"] = llm_router.get_deployment(model_id=model_id)
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=self.data.get("litellm_call_id", ""), status="success"
)
)
if (
"stream" in self.data and self.data["stream"] is True
): # use generate_responses to stream responses
custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
call_id=logging_obj.litellm_call_id,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion,
request_data=self.data,
hidden_params=hidden_params,
**additional_headers,
)
selected_data_generator = select_data_generator(
response=response,
user_api_key_dict=user_api_key_dict,
request_data=self.data,
)
return StreamingResponse(
selected_data_generator,
media_type="text/event-stream",
headers=custom_headers,
)
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
data=self.data, user_api_key_dict=user_api_key_dict, response=response
)
hidden_params = (
getattr(response, "_hidden_params", {}) or {}
) # get any updated response headers
additional_headers = hidden_params.get("additional_headers", {}) or {}
fastapi_response.headers.update(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
call_id=logging_obj.litellm_call_id,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion,
request_data=self.data,
hidden_params=hidden_params,
**additional_headers,
)
)
await check_response_size_is_safe(response=response)
return response
async def _handle_llm_api_exception(
self,
e: Exception,
user_api_key_dict: UserAPIKeyAuth,
proxy_logging_obj: ProxyLogging,
version: Optional[str] = None,
):
"""Raises ProxyException (OpenAI API compatible) if an exception is raised"""
verbose_proxy_logger.exception(
f"litellm.proxy.proxy_server._handle_llm_api_exception(): Exception occured - {str(e)}"
)
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=self.data,
)
litellm_debug_info = getattr(e, "litellm_debug_info", "")
verbose_proxy_logger.debug(
"\033[1;31mAn error occurred: %s %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
e,
litellm_debug_info,
)
timeout = getattr(
e, "timeout", None
) # returns the timeout set by the wrapper. Used for testing if model-specific timeout are set correctly
_litellm_logging_obj: Optional[LiteLLMLoggingObj] = self.data.get(
"litellm_logging_obj", None
)
custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
call_id=(
_litellm_logging_obj.litellm_call_id if _litellm_logging_obj else None
),
version=version,
response_cost=0,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=self.data,
timeout=timeout,
)
headers = getattr(e, "headers", {}) or {}
headers.update(custom_headers)
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
headers=headers,
)
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
openai_code=getattr(e, "code", None),
code=getattr(e, "status_code", 500),
headers=headers,
)
@staticmethod
def _get_pre_call_type(
route_type: Literal["acompletion", "aresponses"]
) -> Literal["completion", "responses"]:
if route_type == "acompletion":
return "completion"
elif route_type == "aresponses":
return "responses"

View file

@ -61,6 +61,7 @@ class MyCustomHandler(
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
pass

View file

@ -66,6 +66,7 @@ class myCustomGuardrail(CustomGuardrail):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
"""

View file

@ -15,6 +15,7 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.utils import handle_exception_on_proxy
router = APIRouter()
@ -97,7 +98,6 @@ async def create_fine_tuning_job(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
premium_user,
proxy_config,
proxy_logging_obj,
@ -151,7 +151,7 @@ async def create_fine_tuning_job(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -205,7 +205,6 @@ async def retrieve_fine_tuning_job(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
premium_user,
proxy_config,
proxy_logging_obj,
@ -248,7 +247,7 @@ async def retrieve_fine_tuning_job(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -305,7 +304,6 @@ async def list_fine_tuning_jobs(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
premium_user,
proxy_config,
proxy_logging_obj,
@ -349,7 +347,7 @@ async def list_fine_tuning_jobs(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -404,7 +402,6 @@ async def cancel_fine_tuning_job(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
premium_user,
proxy_config,
proxy_logging_obj,
@ -451,7 +448,7 @@ async def cancel_fine_tuning_job(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

View file

@ -25,8 +25,12 @@ class AimGuardrailMissingSecrets(Exception):
class AimGuardrail(CustomGuardrail):
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs):
self.async_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.GuardrailCallback)
def __init__(
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
):
self.async_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.GuardrailCallback
)
self.api_key = api_key or os.environ.get("AIM_API_KEY")
if not self.api_key:
msg = (
@ -34,7 +38,9 @@ class AimGuardrail(CustomGuardrail):
"pass it as a parameter to the guardrail in the config file"
)
raise AimGuardrailMissingSecrets(msg)
self.api_base = api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
self.api_base = (
api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
)
super().__init__(**kwargs)
async def async_pre_call_hook(
@ -68,6 +74,7 @@ class AimGuardrail(CustomGuardrail):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
) -> Union[Exception, str, dict, None]:
verbose_proxy_logger.debug("Inside AIM Moderation Hook")
@ -77,9 +84,10 @@ class AimGuardrail(CustomGuardrail):
async def call_aim_guardrail(self, data: dict, hook: str) -> None:
user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
headers = {"Authorization": f"Bearer {self.api_key}", "x-aim-litellm-hook": hook} | (
{"x-aim-user-email": user_email} if user_email else {}
)
headers = {
"Authorization": f"Bearer {self.api_key}",
"x-aim-litellm-hook": hook,
} | ({"x-aim-user-email": user_email} if user_email else {})
response = await self.async_handler.post(
f"{self.api_base}/detect/openai",
headers=headers,

View file

@ -178,7 +178,7 @@ class AporiaGuardrail(CustomGuardrail):
pass
@log_guardrail_information
async def async_moderation_hook( ### 👈 KEY CHANGE ###
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
@ -188,6 +188,7 @@ class AporiaGuardrail(CustomGuardrail):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
from litellm.proxy.common_utils.callback_utils import (

View file

@ -240,7 +240,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
)
@log_guardrail_information
async def async_moderation_hook( ### 👈 KEY CHANGE ###
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
@ -250,6 +250,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
from litellm.proxy.common_utils.callback_utils import (

View file

@ -70,6 +70,7 @@ class myCustomGuardrail(CustomGuardrail):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
"""

View file

@ -134,6 +134,7 @@ class lakeraAI_Moderation(CustomGuardrail):
"audio_transcription",
"pass_through_endpoint",
"rerank",
"responses",
],
):
if (
@ -335,7 +336,7 @@ class lakeraAI_Moderation(CustomGuardrail):
)
@log_guardrail_information
async def async_moderation_hook( ### 👈 KEY CHANGE ###
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
@ -345,6 +346,7 @@ class lakeraAI_Moderation(CustomGuardrail):
"image_generation",
"moderation",
"audio_transcription",
"responses",
],
):
if self.event_hook is None:

View file

@ -62,10 +62,18 @@ def _get_metadata_variable_name(request: Request) -> str:
"""
if RouteChecks._is_assistants_api_request(request):
return "litellm_metadata"
if "batches" in request.url.path:
return "litellm_metadata"
if "/v1/messages" in request.url.path:
# anthropic API has a field called metadata
LITELLM_METADATA_ROUTES = [
"batches",
"/v1/messages",
"responses",
]
if any(
[
litellm_metadata_route in request.url.path
for litellm_metadata_route in LITELLM_METADATA_ROUTES
]
):
return "litellm_metadata"
else:
return "metadata"

View file

@ -27,6 +27,7 @@ from litellm import CreateFileRequest, get_secret_str
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.common_utils.openai_endpoint_utils import (
get_custom_llm_provider_from_request_body,
)
@ -145,7 +146,6 @@ async def create_file(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
llm_router,
proxy_config,
proxy_logging_obj,
@ -234,7 +234,7 @@ async def create_file(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -309,7 +309,6 @@ async def get_file_content(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
proxy_config,
proxy_logging_obj,
version,
@ -351,7 +350,7 @@ async def get_file_content(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -437,7 +436,6 @@ async def get_file(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
proxy_config,
proxy_logging_obj,
version,
@ -477,7 +475,7 @@ async def get_file(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -554,7 +552,6 @@ async def delete_file(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
proxy_config,
proxy_logging_obj,
version,
@ -595,7 +592,7 @@ async def delete_file(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -671,7 +668,6 @@ async def list_files(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
proxy_config,
proxy_logging_obj,
version,
@ -712,7 +708,7 @@ async def list_files(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

View file

@ -23,6 +23,7 @@ from litellm.proxy._types import (
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.custom_http import httpxSpecialProvider
@ -106,7 +107,6 @@ async def chat_completion_pass_through_endpoint( # noqa: PLR0915
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
llm_router,
proxy_config,
proxy_logging_obj,
@ -231,7 +231,7 @@ async def chat_completion_pass_through_endpoint( # noqa: PLR0915
verbose_proxy_logger.debug("final response: %s", response)
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

View file

@ -1,10 +1,6 @@
model_list:
- model_name: thinking-us.anthropic.claude-3-7-sonnet-20250219-v1:0
- model_name: gpt-4o
litellm_params:
model: bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0
thinking: {"type": "enabled", "budget_tokens": 1024}
max_tokens: 1080
merge_reasoning_content_in_choices: true
model: gpt-4o

View file

@ -139,12 +139,9 @@ from litellm.proxy.batches_endpoints.endpoints import router as batches_router
## Import All Misc routes here ##
from litellm.proxy.caching_routes import router as caching_router
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.common_utils.admin_ui_utils import html_form
from litellm.proxy.common_utils.callback_utils import (
get_logging_caching_headers,
get_remaining_tokens_and_requests_from_request_data,
initialize_callbacks_on_proxy,
)
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
@ -236,6 +233,7 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
router as pass_through_router,
)
from litellm.proxy.rerank_endpoints.endpoints import router as rerank_router
from litellm.proxy.response_api_endpoints.endpoints import router as response_router
from litellm.proxy.route_llm_request import route_request
from litellm.proxy.spend_tracking.spend_management_endpoints import (
router as spend_management_router,
@ -783,69 +781,6 @@ db_writer_client: Optional[AsyncHTTPHandler] = None
### logger ###
def get_custom_headers(
*,
user_api_key_dict: UserAPIKeyAuth,
call_id: Optional[str] = None,
model_id: Optional[str] = None,
cache_key: Optional[str] = None,
api_base: Optional[str] = None,
version: Optional[str] = None,
model_region: Optional[str] = None,
response_cost: Optional[Union[float, str]] = None,
hidden_params: Optional[dict] = None,
fastest_response_batch_completion: Optional[bool] = None,
request_data: Optional[dict] = {},
timeout: Optional[Union[float, int, httpx.Timeout]] = None,
**kwargs,
) -> dict:
exclude_values = {"", None, "None"}
hidden_params = hidden_params or {}
headers = {
"x-litellm-call-id": call_id,
"x-litellm-model-id": model_id,
"x-litellm-cache-key": cache_key,
"x-litellm-model-api-base": api_base,
"x-litellm-version": version,
"x-litellm-model-region": model_region,
"x-litellm-response-cost": str(response_cost),
"x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
"x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
"x-litellm-key-max-budget": str(user_api_key_dict.max_budget),
"x-litellm-key-spend": str(user_api_key_dict.spend),
"x-litellm-response-duration-ms": str(hidden_params.get("_response_ms", None)),
"x-litellm-overhead-duration-ms": str(
hidden_params.get("litellm_overhead_time_ms", None)
),
"x-litellm-fastest_response_batch_completion": (
str(fastest_response_batch_completion)
if fastest_response_batch_completion is not None
else None
),
"x-litellm-timeout": str(timeout) if timeout is not None else None,
**{k: str(v) for k, v in kwargs.items()},
}
if request_data:
remaining_tokens_header = get_remaining_tokens_and_requests_from_request_data(
request_data
)
headers.update(remaining_tokens_header)
logging_caching_headers = get_logging_caching_headers(request_data)
if logging_caching_headers:
headers.update(logging_caching_headers)
try:
return {
key: str(value)
for key, value in headers.items()
if value not in exclude_values
}
except Exception as e:
verbose_proxy_logger.error(f"Error setting custom headers: {e}")
return {}
async def check_request_disconnection(request: Request, llm_api_call_task):
"""
Asynchronously checks if the request is disconnected at regular intervals.
@ -3518,169 +3453,28 @@ async def chat_completion( # noqa: PLR0915
"""
global general_settings, user_debug, proxy_logging_obj, llm_model_list
data = {}
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
data = await _read_request_body(request=request)
base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
try:
data = await _read_request_body(request=request)
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
)
data = await add_litellm_data_to_request(
data=data,
return await base_llm_response_processor.base_process_llm_request(
request=request,
general_settings=general_settings,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or model # for azure deployments
or data.get("model", None) # default passed in http request
)
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
# override with user settings, these are params passed via cli
if user_temperature:
data["temperature"] = user_temperature
if user_request_timeout:
data["request_timeout"] = user_request_timeout
if user_max_tokens:
data["max_tokens"] = user_max_tokens
if user_api_base:
data["api_base"] = user_api_base
### MODEL ALIAS MAPPING ###
# check if model name in model alias map
# get the actual model name
if isinstance(data["model"], str) and data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]]
### CALL HOOKS ### - modify/reject incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
)
## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse.
data["litellm_call_id"] = request.headers.get(
"x-litellm-call-id", str(uuid.uuid4())
)
logging_obj, data = litellm.utils.function_setup(
original_function="acompletion",
rules_obj=litellm.utils.Rules(),
start_time=datetime.now(),
**data,
)
data["litellm_logging_obj"] = logging_obj
tasks = []
tasks.append(
proxy_logging_obj.during_call_hook(
data=data,
user_api_key_dict=user_api_key_dict,
call_type="completion",
)
)
### ROUTE THE REQUEST ###
# Do not change this - it should be a constant time fetch - ALWAYS
llm_call = await route_request(
data=data,
route_type="acompletion",
proxy_logging_obj=proxy_logging_obj,
llm_router=llm_router,
general_settings=general_settings,
proxy_config=proxy_config,
select_data_generator=select_data_generator,
model=model,
user_model=user_model,
user_temperature=user_temperature,
user_request_timeout=user_request_timeout,
user_max_tokens=user_max_tokens,
user_api_base=user_api_base,
version=version,
)
tasks.append(llm_call)
# wait for call to end
llm_responses = asyncio.gather(
*tasks
) # run the moderation check in parallel to the actual llm api call
responses = await llm_responses
response = responses[1]
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
fastest_response_batch_completion = hidden_params.get(
"fastest_response_batch_completion", None
)
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
# Post Call Processing
if llm_router is not None:
data["deployment"] = llm_router.get_deployment(model_id=model_id)
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
if (
"stream" in data and data["stream"] is True
): # use generate_responses to stream responses
custom_headers = get_custom_headers(
user_api_key_dict=user_api_key_dict,
call_id=logging_obj.litellm_call_id,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion,
request_data=data,
hidden_params=hidden_params,
**additional_headers,
)
selected_data_generator = select_data_generator(
response=response,
user_api_key_dict=user_api_key_dict,
request_data=data,
)
return StreamingResponse(
selected_data_generator,
media_type="text/event-stream",
headers=custom_headers,
)
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response
)
hidden_params = (
getattr(response, "_hidden_params", {}) or {}
) # get any updated response headers
additional_headers = hidden_params.get("additional_headers", {}) or {}
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
call_id=logging_obj.litellm_call_id,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion,
request_data=data,
hidden_params=hidden_params,
**additional_headers,
)
)
await check_response_size_is_safe(response=response)
return response
except RejectedRequestError as e:
_data = e.request_data
await proxy_logging_obj.post_call_failure_hook(
@ -3715,55 +3509,10 @@ async def chat_completion( # noqa: PLR0915
_chat_response.usage = _usage # type: ignore
return _chat_response
except Exception as e:
verbose_proxy_logger.exception(
f"litellm.proxy.proxy_server.chat_completion(): Exception occured - {str(e)}"
)
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
litellm_debug_info = getattr(e, "litellm_debug_info", "")
verbose_proxy_logger.debug(
"\033[1;31mAn error occurred: %s %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
e,
litellm_debug_info,
)
timeout = getattr(
e, "timeout", None
) # returns the timeout set by the wrapper. Used for testing if model-specific timeout are set correctly
_litellm_logging_obj: Optional[LiteLLMLoggingObj] = data.get(
"litellm_logging_obj", None
)
custom_headers = get_custom_headers(
raise await base_llm_response_processor._handle_llm_api_exception(
e=e,
user_api_key_dict=user_api_key_dict,
call_id=(
_litellm_logging_obj.litellm_call_id if _litellm_logging_obj else None
),
version=version,
response_cost=0,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
timeout=timeout,
)
headers = getattr(e, "headers", {}) or {}
headers.update(custom_headers)
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
headers=headers,
)
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
openai_code=getattr(e, "code", None),
code=getattr(e, "status_code", 500),
headers=headers,
proxy_logging_obj=proxy_logging_obj,
)
@ -3880,7 +3629,7 @@ async def completion( # noqa: PLR0915
if (
"stream" in data and data["stream"] is True
): # use generate_responses to stream responses
custom_headers = get_custom_headers(
custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
call_id=litellm_call_id,
model_id=model_id,
@ -3908,7 +3657,7 @@ async def completion( # noqa: PLR0915
)
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
call_id=litellm_call_id,
model_id=model_id,
@ -4139,7 +3888,7 @@ async def embeddings( # noqa: PLR0915
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -4267,7 +4016,7 @@ async def image_generation(
litellm_call_id = hidden_params.get("litellm_call_id", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -4388,7 +4137,7 @@ async def audio_speech(
async for chunk in _generator:
yield chunk
custom_headers = get_custom_headers(
custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -4529,7 +4278,7 @@ async def audio_transcriptions(
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -4681,7 +4430,7 @@ async def get_assistants(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -4780,7 +4529,7 @@ async def create_assistant(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -4877,7 +4626,7 @@ async def delete_assistant(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -4974,7 +4723,7 @@ async def create_threads(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -5070,7 +4819,7 @@ async def get_thread(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -5169,7 +4918,7 @@ async def add_messages(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -5264,7 +5013,7 @@ async def get_messages(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -5373,7 +5122,7 @@ async def run_thread(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -5496,7 +5245,7 @@ async def moderations(
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
@ -8640,6 +8389,7 @@ async def get_routes():
app.include_router(router)
app.include_router(response_router)
app.include_router(batches_router)
app.include_router(rerank_router)
app.include_router(fine_tuning_router)

View file

@ -7,10 +7,12 @@ from fastapi.responses import ORJSONResponse
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
router = APIRouter()
import asyncio
@router.post(
"/v2/rerank",
dependencies=[Depends(user_api_key_auth)],
@ -37,7 +39,6 @@ async def rerank(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
llm_router,
proxy_config,
proxy_logging_obj,
@ -89,7 +90,7 @@ async def rerank(
api_base = hidden_params.get("api_base", None) or ""
additional_headers = hidden_params.get("additional_headers", None) or {}
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

View file

@ -0,0 +1,80 @@
from fastapi import APIRouter, Depends, Request, Response
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth, user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
router = APIRouter()
@router.post(
"/v1/responses",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
@router.post(
"/responses",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
async def responses_api(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Follows the OpenAI Responses API spec: https://platform.openai.com/docs/api-reference/responses
```bash
curl -X POST http://localhost:4000/v1/responses \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gpt-4o",
"input": "Tell me about AI"
}'
```
"""
from litellm.proxy.proxy_server import (
_read_request_body,
general_settings,
llm_router,
proxy_config,
proxy_logging_obj,
select_data_generator,
user_api_base,
user_max_tokens,
user_model,
user_request_timeout,
user_temperature,
version,
)
data = await _read_request_body(request=request)
processor = ProxyBaseLLMRequestProcessing(data=data)
try:
return await processor.base_process_llm_request(
request=request,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
route_type="aresponses",
proxy_logging_obj=proxy_logging_obj,
llm_router=llm_router,
general_settings=general_settings,
proxy_config=proxy_config,
select_data_generator=select_data_generator,
model=None,
user_model=user_model,
user_temperature=user_temperature,
user_request_timeout=user_request_timeout,
user_max_tokens=user_max_tokens,
user_api_base=user_api_base,
version=version,
)
except Exception as e:
raise await processor._handle_llm_api_exception(
e=e,
user_api_key_dict=user_api_key_dict,
proxy_logging_obj=proxy_logging_obj,
version=version,
)

View file

@ -21,6 +21,7 @@ ROUTE_ENDPOINT_MAPPING = {
"atranscription": "/audio/transcriptions",
"amoderation": "/moderations",
"arerank": "/rerank",
"aresponses": "/responses",
}
@ -45,6 +46,7 @@ async def route_request(
"atranscription",
"amoderation",
"arerank",
"aresponses",
"_arealtime", # private function for realtime API
],
):

View file

@ -537,6 +537,7 @@ class ProxyLogging:
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal[
"completion",
"responses",
"embeddings",
"image_generation",
"moderation",

View file

@ -581,13 +581,7 @@ class Router:
self._initialize_alerting()
self.initialize_assistants_endpoint()
self.amoderation = self.factory_function(
litellm.amoderation, call_type="moderation"
)
self.aanthropic_messages = self.factory_function(
litellm.anthropic_messages, call_type="anthropic_messages"
)
self.initialize_router_endpoints()
def discard(self):
"""
@ -653,6 +647,18 @@ class Router:
self.aget_messages = self.factory_function(litellm.aget_messages)
self.arun_thread = self.factory_function(litellm.arun_thread)
def initialize_router_endpoints(self):
self.amoderation = self.factory_function(
litellm.amoderation, call_type="moderation"
)
self.aanthropic_messages = self.factory_function(
litellm.anthropic_messages, call_type="anthropic_messages"
)
self.aresponses = self.factory_function(
litellm.aresponses, call_type="aresponses"
)
self.responses = self.factory_function(litellm.responses, call_type="responses")
def routing_strategy_init(
self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
):
@ -1080,17 +1086,22 @@ class Router:
kwargs.setdefault("litellm_trace_id", str(uuid.uuid4()))
kwargs.setdefault("metadata", {}).update({"model_group": model})
def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
def _update_kwargs_with_default_litellm_params(
self, kwargs: dict, metadata_variable_name: Optional[str] = "metadata"
) -> None:
"""
Adds default litellm params to kwargs, if set.
"""
self.default_litellm_params[metadata_variable_name] = (
self.default_litellm_params.pop("metadata", {})
)
for k, v in self.default_litellm_params.items():
if (
k not in kwargs and v is not None
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
elif k == metadata_variable_name:
kwargs[metadata_variable_name].update(v)
def _handle_clientside_credential(
self, deployment: dict, kwargs: dict
@ -1121,7 +1132,12 @@ class Router:
) # add new deployment to router
return deployment_pydantic_obj
def _update_kwargs_with_deployment(self, deployment: dict, kwargs: dict) -> None:
def _update_kwargs_with_deployment(
self,
deployment: dict,
kwargs: dict,
function_name: Optional[str] = None,
) -> None:
"""
2 jobs:
- Adds selected deployment, model_info and api_base to kwargs["metadata"] (used for logging)
@ -1138,7 +1154,10 @@ class Router:
deployment_model_name = deployment_pydantic_obj.litellm_params.model
deployment_api_base = deployment_pydantic_obj.litellm_params.api_base
kwargs.setdefault("metadata", {}).update(
metadata_variable_name = _get_router_metadata_variable_name(
function_name=function_name,
)
kwargs.setdefault(metadata_variable_name, {}).update(
{
"deployment": deployment_model_name,
"model_info": model_info,
@ -1151,7 +1170,9 @@ class Router:
kwargs=kwargs, data=deployment["litellm_params"]
)
self._update_kwargs_with_default_litellm_params(kwargs=kwargs)
self._update_kwargs_with_default_litellm_params(
kwargs=kwargs, metadata_variable_name=metadata_variable_name
)
def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
"""
@ -2396,22 +2417,18 @@ class Router:
messages=kwargs.get("messages", None),
specific_deployment=kwargs.pop("specific_deployment", None),
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
self._update_kwargs_with_deployment(
deployment=deployment, kwargs=kwargs, function_name="generic_api_call"
)
data = deployment["litellm_params"].copy()
model_name = data["model"]
model_client = self._get_async_openai_model_client(
deployment=deployment,
kwargs=kwargs,
)
self.total_calls[model_name] += 1
response = original_function(
**{
**data,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
)
@ -2453,6 +2470,61 @@ class Router:
self.fail_calls[model] += 1
raise e
def _generic_api_call_with_fallbacks(
self, model: str, original_function: Callable, **kwargs
):
"""
Make a generic LLM API call through the router, this allows you to use retries/fallbacks with litellm router
Args:
model: The model to use
original_function: The handler function to call (e.g., litellm.completion)
**kwargs: Additional arguments to pass to the handler function
Returns:
The response from the handler function
"""
handler_name = original_function.__name__
try:
verbose_router_logger.debug(
f"Inside _generic_api_call() - handler: {handler_name}, model: {model}; kwargs: {kwargs}"
)
deployment = self.get_available_deployment(
model=model,
messages=kwargs.get("messages", None),
specific_deployment=kwargs.pop("specific_deployment", None),
)
self._update_kwargs_with_deployment(
deployment=deployment, kwargs=kwargs, function_name="generic_api_call"
)
data = deployment["litellm_params"].copy()
model_name = data["model"]
self.total_calls[model_name] += 1
# Perform pre-call checks for routing strategy
self.routing_strategy_pre_call_checks(deployment=deployment)
response = original_function(
**{
**data,
"caching": self.cache_responses,
**kwargs,
}
)
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"{handler_name}(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"{handler_name}(model={model})\033[31m Exception {str(e)}\033[0m"
)
if model is not None:
self.fail_calls[model] += 1
raise e
def embedding(
self,
model: str,
@ -2974,14 +3046,42 @@ class Router:
self,
original_function: Callable,
call_type: Literal[
"assistants", "moderation", "anthropic_messages"
"assistants",
"moderation",
"anthropic_messages",
"aresponses",
"responses",
] = "assistants",
):
async def new_function(
"""
Creates appropriate wrapper functions for different API call types.
Returns:
- A synchronous function for synchronous call types
- An asynchronous function for asynchronous call types
"""
# Handle synchronous call types
if call_type == "responses":
def sync_wrapper(
custom_llm_provider: Optional[
Literal["openai", "azure", "anthropic"]
] = None,
client: Optional[Any] = None,
**kwargs,
):
return self._generic_api_call_with_fallbacks(
original_function=original_function, **kwargs
)
return sync_wrapper
# Handle asynchronous call types
async def async_wrapper(
custom_llm_provider: Optional[
Literal["openai", "azure", "anthropic"]
] = None,
client: Optional["AsyncOpenAI"] = None,
client: Optional[Any] = None,
**kwargs,
):
if call_type == "assistants":
@ -2992,18 +3092,16 @@ class Router:
**kwargs,
)
elif call_type == "moderation":
return await self._pass_through_moderation_endpoint_factory( # type: ignore
original_function=original_function,
**kwargs,
return await self._pass_through_moderation_endpoint_factory(
original_function=original_function, **kwargs
)
elif call_type == "anthropic_messages":
return await self._ageneric_api_call_with_fallbacks( # type: ignore
elif call_type in ("anthropic_messages", "aresponses"):
return await self._ageneric_api_call_with_fallbacks(
original_function=original_function,
**kwargs,
)
return new_function
return async_wrapper
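The factory above returns a plain function for sync call types and a coroutine function otherwise. A simplified, standalone sketch of that pattern (illustrative names, not the Router's real attributes):

```python
import asyncio
from typing import Any, Callable


def make_endpoint(handler: Callable[..., Any], call_type: str) -> Callable[..., Any]:
    """Return a sync wrapper for sync call types, an async wrapper for the rest."""
    if call_type == "responses":  # synchronous path

        def sync_wrapper(**kwargs):
            return handler(**kwargs)

        return sync_wrapper

    async def async_wrapper(**kwargs):  # asynchronous path
        return await handler(**kwargs)

    return async_wrapper


def echo(**kwargs):
    return kwargs


async def aecho(**kwargs):
    return kwargs


sync_fn = make_endpoint(echo, call_type="responses")
print(sync_fn(input="hi"))                # {'input': 'hi'}

async_fn = make_endpoint(aecho, call_type="aresponses")
print(asyncio.run(async_fn(input="hi")))  # {'input': 'hi'}
```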
async def _pass_through_assistants_endpoint_factory(
self,

View file

@ -56,7 +56,8 @@ def _get_router_metadata_variable_name(function_name) -> str:
    For ALL other endpoints we call this "metadata"
"""
if "batch" in function_name:
ROUTER_METHODS_USING_LITELLM_METADATA = set(["batch", "generic_api_call"])
if function_name in ROUTER_METHODS_USING_LITELLM_METADATA:
return "litellm_metadata"
else:
return "metadata"

View file

@ -742,6 +742,9 @@ class BaseLiteLLMOpenAIResponseObject(BaseModel):
def __contains__(self, key):
return key in self.__dict__
def items(self):
return self.__dict__.items()
class OutputTokensDetails(BaseLiteLLMOpenAIResponseObject):
reasoning_tokens: int
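The new `items()` (alongside the existing `__contains__`) lets these response objects be consumed like mappings, e.g. `dict(response.items())` or membership checks. A rough sketch of the pattern with a stand-in Pydantic model, not the actual class hierarchy:

```python
from pydantic import BaseModel


class DemoResponseObject(BaseModel):
    id: str
    status: str

    # Same dict-like helpers as BaseLiteLLMOpenAIResponseObject
    def __contains__(self, key):
        return key in self.__dict__

    def items(self):
        return self.__dict__.items()


resp = DemoResponseObject(id="resp_123", status="completed")
assert "status" in resp
print(dict(resp.items()))  # {'id': 'resp_123', 'status': 'completed'}
```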

View file

@ -230,6 +230,7 @@ def test_select_azure_base_url_called(setup_mocks):
"anthropic_messages",
"add_message",
"arun_thread_stream",
"aresponses",
]
],
)

View file

@ -3,6 +3,7 @@ import sys
import pytest
import asyncio
from typing import Optional
from unittest.mock import patch, AsyncMock
sys.path.insert(0, os.path.abspath("../.."))
import litellm
@ -16,6 +17,7 @@ from litellm.types.llms.openai import (
ResponseAPIUsage,
IncompleteDetails,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
def validate_responses_api_response(response, final_chunk: bool = False):
@ -503,3 +505,293 @@ async def test_openai_responses_api_streaming_validation(sync_mode):
assert not missing_events, f"Missing required event types: {missing_events}"
print(f"Successfully validated all event types: {event_types_seen}")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_openai_responses_litellm_router(sync_mode):
"""
Test the OpenAI responses API with LiteLLM Router in both sync and async modes
"""
litellm._turn_on_debug()
router = litellm.Router(
model_list=[
{
"model_name": "gpt4o-special-alias",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
}
]
)
# Call the handler
if sync_mode:
response = router.responses(
model="gpt4o-special-alias",
input="Hello, can you tell me a short joke?",
max_output_tokens=100,
)
print("SYNC MODE RESPONSE=", response)
else:
response = await router.aresponses(
model="gpt4o-special-alias",
input="Hello, can you tell me a short joke?",
max_output_tokens=100,
)
print(
f"Router {'sync' if sync_mode else 'async'} response=",
json.dumps(response, indent=4, default=str),
)
# Use the helper function to validate the response
validate_responses_api_response(response, final_chunk=True)
return response
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_openai_responses_litellm_router_streaming(sync_mode):
"""
Test the OpenAI responses API with streaming through LiteLLM Router
"""
litellm._turn_on_debug()
router = litellm.Router(
model_list=[
{
"model_name": "gpt4o-special-alias",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
}
]
)
event_types_seen = set()
if sync_mode:
response = router.responses(
model="gpt4o-special-alias",
input="Tell me about artificial intelligence in 2 sentences.",
stream=True,
)
for event in response:
print(f"Validating event type: {event.type}")
validate_stream_event(event)
event_types_seen.add(event.type)
else:
response = await router.aresponses(
model="gpt4o-special-alias",
input="Tell me about artificial intelligence in 2 sentences.",
stream=True,
)
async for event in response:
print(f"Validating event type: {event.type}")
validate_stream_event(event)
event_types_seen.add(event.type)
# At minimum, we should see these core event types
required_events = {"response.created", "response.completed"}
missing_events = required_events - event_types_seen
assert not missing_events, f"Missing required event types: {missing_events}"
print(f"Successfully validated all event types: {event_types_seen}")
@pytest.mark.asyncio
async def test_openai_responses_litellm_router_no_metadata():
"""
    Test that metadata is not passed through when using the Router for the responses API
"""
mock_response = {
"id": "resp_123",
"object": "response",
"created_at": 1741476542,
"status": "completed",
"model": "gpt-4o",
"output": [
{
"type": "message",
"id": "msg_123",
"status": "completed",
"role": "assistant",
"content": [
{"type": "output_text", "text": "Hello world!", "annotations": []}
],
}
],
"parallel_tool_calls": True,
"usage": {
"input_tokens": 10,
"output_tokens": 20,
"total_tokens": 30,
"output_tokens_details": {"reasoning_tokens": 0},
},
"text": {"format": {"type": "text"}},
# Adding all required fields
"error": None,
"incomplete_details": None,
"instructions": None,
"metadata": {},
"temperature": 1.0,
"tool_choice": "auto",
"tools": [],
"top_p": 1.0,
"max_output_tokens": None,
"previous_response_id": None,
"reasoning": {"effort": None, "summary": None},
"truncation": "disabled",
"user": None,
}
class MockResponse:
def __init__(self, json_data, status_code):
self._json_data = json_data
self.status_code = status_code
self.text = str(json_data)
def json(self): # Changed from async to sync
return self._json_data
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
new_callable=AsyncMock,
) as mock_post:
# Configure the mock to return our response
mock_post.return_value = MockResponse(mock_response, 200)
litellm._turn_on_debug()
router = litellm.Router(
model_list=[
{
"model_name": "gpt4o-special-alias",
"litellm_params": {
"model": "gpt-4o",
"api_key": "fake-key",
},
}
]
)
# Call the handler with metadata
await router.aresponses(
model="gpt4o-special-alias",
input="Hello, can you tell me a short joke?",
)
# Check the request body
request_body = mock_post.call_args.kwargs["data"]
print("Request body:", json.dumps(request_body, indent=4))
loaded_request_body = json.loads(request_body)
print("Loaded request body:", json.dumps(loaded_request_body, indent=4))
# Assert metadata is not in the request
        assert (
            loaded_request_body["metadata"] is None
        ), "metadata should be None in the request body"
mock_post.assert_called_once()
@pytest.mark.asyncio
async def test_openai_responses_litellm_router_with_metadata():
"""
    Test that metadata is correctly passed through when explicitly provided to the Router for the responses API
"""
test_metadata = {
"user_id": "123",
"conversation_id": "abc",
"custom_field": "test_value",
}
mock_response = {
"id": "resp_123",
"object": "response",
"created_at": 1741476542,
"status": "completed",
"model": "gpt-4o",
"output": [
{
"type": "message",
"id": "msg_123",
"status": "completed",
"role": "assistant",
"content": [
{"type": "output_text", "text": "Hello world!", "annotations": []}
],
}
],
"parallel_tool_calls": True,
"usage": {
"input_tokens": 10,
"output_tokens": 20,
"total_tokens": 30,
"output_tokens_details": {"reasoning_tokens": 0},
},
"text": {"format": {"type": "text"}},
"error": None,
"incomplete_details": None,
"instructions": None,
"metadata": test_metadata, # Include the test metadata in response
"temperature": 1.0,
"tool_choice": "auto",
"tools": [],
"top_p": 1.0,
"max_output_tokens": None,
"previous_response_id": None,
"reasoning": {"effort": None, "summary": None},
"truncation": "disabled",
"user": None,
}
class MockResponse:
def __init__(self, json_data, status_code):
self._json_data = json_data
self.status_code = status_code
self.text = str(json_data)
def json(self):
return self._json_data
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
new_callable=AsyncMock,
) as mock_post:
# Configure the mock to return our response
mock_post.return_value = MockResponse(mock_response, 200)
litellm._turn_on_debug()
router = litellm.Router(
model_list=[
{
"model_name": "gpt4o-special-alias",
"litellm_params": {
"model": "gpt-4o",
"api_key": "fake-key",
},
}
]
)
# Call the handler with metadata
await router.aresponses(
model="gpt4o-special-alias",
input="Hello, can you tell me a short joke?",
metadata=test_metadata,
)
# Check the request body
request_body = mock_post.call_args.kwargs["data"]
loaded_request_body = json.loads(request_body)
print("Request body:", json.dumps(loaded_request_body, indent=4))
# Assert metadata matches exactly what was passed
assert (
loaded_request_body["metadata"] == test_metadata
), "metadata in request body should match what was passed"
mock_post.assert_called_once()

View file

@ -315,14 +315,20 @@ async def test_router_with_empty_choices(model_list):
assert response is not None
@pytest.mark.asyncio
async def test_ageneric_api_call_with_fallbacks_basic():
@pytest.mark.parametrize("sync_mode", [True, False])
def test_generic_api_call_with_fallbacks_basic(sync_mode):
"""
Test the _ageneric_api_call_with_fallbacks method with a basic successful call
Test both the sync and async versions of generic_api_call_with_fallbacks with a basic successful call
"""
# Create a mock function that will be passed to _ageneric_api_call_with_fallbacks
mock_function = AsyncMock()
mock_function.__name__ = "test_function"
# Create a mock function that will be passed to generic_api_call_with_fallbacks
if sync_mode:
from unittest.mock import Mock
mock_function = Mock()
mock_function.__name__ = "test_function"
else:
mock_function = AsyncMock()
mock_function.__name__ = "test_function"
# Create a mock response
mock_response = {
@ -347,13 +353,23 @@ async def test_ageneric_api_call_with_fallbacks_basic():
]
)
# Call the _ageneric_api_call_with_fallbacks method
response = await router._ageneric_api_call_with_fallbacks(
model="test-model-alias",
original_function=mock_function,
messages=[{"role": "user", "content": "Hello"}],
max_tokens=100,
)
# Call the appropriate generic_api_call_with_fallbacks method
if sync_mode:
response = router._generic_api_call_with_fallbacks(
model="test-model-alias",
original_function=mock_function,
messages=[{"role": "user", "content": "Hello"}],
max_tokens=100,
)
else:
response = asyncio.run(
router._ageneric_api_call_with_fallbacks(
model="test-model-alias",
original_function=mock_function,
messages=[{"role": "user", "content": "Hello"}],
max_tokens=100,
)
)
# Verify the mock function was called
mock_function.assert_called_once()
@ -510,3 +526,36 @@ async def test__aadapter_completion():
# Verify async_routing_strategy_pre_call_checks was called
router.async_routing_strategy_pre_call_checks.assert_called_once()
def test_initialize_router_endpoints():
"""
Test that initialize_router_endpoints correctly sets up all router endpoints
"""
# Create a router with a basic model
router = Router(
model_list=[
{
"model_name": "test-model",
"litellm_params": {
"model": "anthropic/test-model",
"api_key": "fake-api-key",
},
}
]
)
# Explicitly call initialize_router_endpoints
router.initialize_router_endpoints()
# Verify all expected endpoints are initialized
assert hasattr(router, "amoderation")
assert hasattr(router, "aanthropic_messages")
assert hasattr(router, "aresponses")
assert hasattr(router, "responses")
# Verify the endpoints are callable
assert callable(router.amoderation)
assert callable(router.aanthropic_messages)
assert callable(router.aresponses)
assert callable(router.responses)