Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)

Merge pull request #9183 from BerriAI/litellm_router_responses_api_2

[Feat] - Add Responses API on LiteLLM Proxy

Commit 1d31e25816: 44 changed files with 1,165 additions and 401 deletions
@@ -1,7 +1,7 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';

-# [BETA] `/v1/messages`
+# /v1/messages [BETA]

 LiteLLM provides a BETA endpoint in the spec of Anthropic's `/v1/messages` endpoint.

@@ -1,7 +1,7 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';

-# Assistants API
+# /assistants

 Covers Threads, Messages, Assistants.

@@ -1,7 +1,7 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';

-# [BETA] Batches API
+# /batches

 Covers Batches, Files

@@ -1,7 +1,7 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';

-# Embeddings
+# /embeddings

 ## Quick Start
 ```python

@@ -2,7 +2,7 @@
 import TabItem from '@theme/TabItem';
 import Tabs from '@theme/Tabs';

-# Files API
+# /files

 Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.

@@ -1,7 +1,7 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';

-# [Beta] Fine-tuning API
+# /fine_tuning

 :::info

@@ -1,7 +1,7 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';

-# Moderation
+# /moderations

 ### Usage

@@ -1,7 +1,7 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';

-# Realtime Endpoints
+# /realtime

 Use this to loadbalance across Azure + OpenAI.

@@ -1,4 +1,4 @@
-# Rerank
+# /rerank

 :::tip
docs/my-website/docs/response_api.md (new file, 117 lines)

import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# /responses

LiteLLM provides a BETA endpoint in the spec of [OpenAI's `/responses` API](https://platform.openai.com/docs/api-reference/responses)

| Feature | Supported | Notes |
|---------|-----------|--------|
| Cost Tracking | ✅ | Works with all supported models |
| Logging | ✅ | Works across all integrations |
| End-user Tracking | ✅ | |
| Streaming | ✅ | |
| Fallbacks | ✅ | Works between supported models |
| Loadbalancing | ✅ | Works between supported models |
| Supported LiteLLM Versions | 1.63.8+ | |
| Supported LLM providers | `openai` | |

## Usage

## Create a model response

<Tabs>
<TabItem value="litellm-sdk" label="LiteLLM SDK">

#### Non-streaming
```python
import litellm

# Non-streaming response
response = litellm.responses(
    model="gpt-4o",
    input="Tell me a three sentence bedtime story about a unicorn.",
    max_output_tokens=100
)

print(response)
```

#### Streaming
```python
import litellm

# Streaming response
response = litellm.responses(
    model="gpt-4o",
    input="Tell me a three sentence bedtime story about a unicorn.",
    stream=True
)

for event in response:
    print(event)
```

</TabItem>
<TabItem value="proxy" label="OpenAI SDK with LiteLLM Proxy">

First, add this to your litellm proxy config.yaml:
```yaml
model_list:
  - model_name: gpt-4o
    litellm_params:
      model: openai/gpt-4
      api_key: os.environ/OPENAI_API_KEY
```

Start your LiteLLM proxy:
```bash
litellm --config /path/to/config.yaml

# RUNNING on http://0.0.0.0:4000
```

Then use the OpenAI SDK pointed to your proxy:

#### Non-streaming
```python
from openai import OpenAI

# Initialize client with your proxy URL
client = OpenAI(
    base_url="http://localhost:4000",  # Your proxy URL
    api_key="your-api-key"             # Your proxy API key
)

# Non-streaming response
response = client.responses.create(
    model="gpt-4o",
    input="Tell me a three sentence bedtime story about a unicorn."
)

print(response)
```

#### Streaming
```python
from openai import OpenAI

# Initialize client with your proxy URL
client = OpenAI(
    base_url="http://localhost:4000",  # Your proxy URL
    api_key="your-api-key"             # Your proxy API key
)

# Streaming response
response = client.responses.create(
    model="gpt-4o",
    input="Tell me a three sentence bedtime story about a unicorn.",
    stream=True
)

for event in response:
    print(event)
```

</TabItem>
</Tabs>
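The feature table above marks fallbacks and load balancing as working between supported models. A minimal sketch of what that could look like with the LiteLLM Router is below; the `aresponses` method name is an assumption inferred from the `aresponses` route type introduced later in this diff, and the API keys and second deployment are placeholders, not part of this PR.

```python
# Sketch (not from the PR): load balancing /responses calls across two
# deployments of the same model group with the LiteLLM Router. Assumes the
# Router exposes `aresponses`, mirroring litellm.responses; keys are placeholders.
import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {"model": "openai/gpt-4o", "api_key": "sk-key-1"},
        },
        {
            "model_name": "gpt-4o",
            "litellm_params": {"model": "openai/gpt-4o", "api_key": "sk-key-2"},
        },
    ],
    num_retries=2,  # failed calls are retried against the listed deployments
)


async def main():
    # The router picks one of the two "gpt-4o" deployments per request
    response = await router.aresponses(
        model="gpt-4o",
        input="Tell me a three sentence bedtime story about a unicorn.",
    )
    print(response)


asyncio.run(main())
```

On the proxy, the same grouping is expressed by listing multiple entries with the same `model_name` in the `model_list` of the config.yaml shown above.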
@@ -1,7 +1,7 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';

-# Text Completion
+# /completions

 ### Usage
 <Tabs>
@@ -273,7 +273,7 @@ const sidebars = {
       items: [
         {
           type: "category",
-          label: "Chat",
+          label: "/chat/completions",
           link: {
             type: "generated-index",
             title: "Chat Completions",

@@ -286,12 +286,13 @@ const sidebars = {
         "completion/usage",
       ],
     },
+    "response_api",
     "text_completion",
     "embedding/supported_embedding",
     "anthropic_unified",
     {
       type: "category",
-      label: "Image",
+      label: "/images",
       items: [
         "image_generation",
         "image_variations",

@@ -299,7 +300,7 @@ const sidebars = {
     },
     {
       type: "category",
-      label: "Audio",
+      label: "/audio",
       "items": [
         "audio_transcription",
         "text_to_speech",
@@ -163,7 +163,7 @@ class AporiaGuardrail(CustomGuardrail):

         pass

-    async def async_moderation_hook(  ### 👈 KEY CHANGE ###
+    async def async_moderation_hook(
         self,
         data: dict,
         user_api_key_dict: UserAPIKeyAuth,

@@ -173,6 +173,7 @@ class AporiaGuardrail(CustomGuardrail):
             "image_generation",
             "moderation",
             "audio_transcription",
+            "responses",
         ],
     ):
         from litellm.proxy.common_utils.callback_utils import (

@@ -94,6 +94,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
             "image_generation",
             "moderation",
             "audio_transcription",
+            "responses",
         ],
     ):
         """

@@ -107,6 +107,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
             "image_generation",
             "moderation",
             "audio_transcription",
+            "responses",
         ],
     ):
         """

@@ -126,6 +126,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
             "image_generation",
             "moderation",
             "audio_transcription",
+            "responses",
         ],
     ):
         """

@@ -31,7 +31,7 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):

     #### CALL HOOKS - proxy only ####

-    async def async_moderation_hook(  ### 👈 KEY CHANGE ###
+    async def async_moderation_hook(
         self,
         data: dict,
         user_api_key_dict: UserAPIKeyAuth,

@@ -41,6 +41,7 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
             "image_generation",
             "moderation",
             "audio_transcription",
+            "responses",
         ],
     ):
         text = ""

@@ -239,6 +239,7 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
             "image_generation",
             "moderation",
             "audio_transcription",
+            "responses",
         ],
     ) -> Any:
         pass
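Each of the hooks above gains `"responses"` as a possible `call_type`, so enterprise guardrails and custom loggers can now intercept Responses API traffic. A minimal sketch of a custom hook that reacts to the new call type follows; the entries before `"responses"` in the `Literal` and the blocking rule itself are illustrative assumptions, not taken from this PR.

```python
# Sketch (illustrative, not from the PR): a custom hook that also fires for
# the new "responses" call type. Literal entries not visible in the hunks
# above ("completion", "embeddings") are assumed from the existing signature.
from typing import Literal

from fastapi import HTTPException

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class ResponsesAwareHook(CustomLogger):
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "responses",
        ],
    ):
        if call_type == "responses":
            # The Responses API carries the prompt in "input" (see docs above)
            text = str(data.get("input", ""))
            if "blocked-term" in text.lower():
                raise HTTPException(
                    status_code=400, detail={"error": "Rejected by custom hook"}
                )
```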
|
@ -14,6 +14,7 @@ import litellm
|
|||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
@ -89,7 +90,6 @@ async def anthropic_response( # noqa: PLR0915
|
|||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
llm_router,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
|
@ -205,7 +205,7 @@ async def anthropic_response( # noqa: PLR0915
|
|||
verbose_proxy_logger.debug("final response: %s", response)
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
|
|
@ -18,6 +18,7 @@ from litellm.batches.main import (
|
|||
)
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||
get_custom_llm_provider_from_request_body,
|
||||
|
@ -69,7 +70,6 @@ async def create_batch(
|
|||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
llm_router,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
|
@ -137,7 +137,7 @@ async def create_batch(
|
|||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
@ -201,7 +201,6 @@ async def retrieve_batch(
|
|||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
llm_router,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
|
@ -266,7 +265,7 @@ async def retrieve_batch(
|
|||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
@ -326,11 +325,7 @@ async def list_batches(
|
|||
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
get_custom_headers,
|
||||
proxy_logging_obj,
|
||||
version,
|
||||
)
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj, version
|
||||
|
||||
verbose_proxy_logger.debug("GET /v1/batches after={} limit={}".format(after, limit))
|
||||
try:
|
||||
|
@ -352,7 +347,7 @@ async def list_batches(
|
|||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
@ -417,7 +412,6 @@ async def cancel_batch(
|
|||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
version,
|
||||
|
@ -463,7 +457,7 @@ async def cancel_batch(
|
|||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
|
litellm/proxy/common_request_processing.py (new file, 356 lines)
|
@ -0,0 +1,356 @@
|
|||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.proxy._types import ProxyException, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
get_logging_caching_headers,
|
||||
get_remaining_tokens_and_requests_from_request_data,
|
||||
)
|
||||
from litellm.proxy.route_llm_request import route_request
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.router import Router
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
|
||||
|
||||
ProxyConfig = _ProxyConfig
|
||||
else:
|
||||
ProxyConfig = Any
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
|
||||
class ProxyBaseLLMRequestProcessing:
|
||||
def __init__(self, data: dict):
|
||||
self.data = data
|
||||
|
||||
@staticmethod
|
||||
def get_custom_headers(
|
||||
*,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_id: Optional[str] = None,
|
||||
model_id: Optional[str] = None,
|
||||
cache_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
model_region: Optional[str] = None,
|
||||
response_cost: Optional[Union[float, str]] = None,
|
||||
hidden_params: Optional[dict] = None,
|
||||
fastest_response_batch_completion: Optional[bool] = None,
|
||||
request_data: Optional[dict] = {},
|
||||
timeout: Optional[Union[float, int, httpx.Timeout]] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
exclude_values = {"", None, "None"}
|
||||
hidden_params = hidden_params or {}
|
||||
headers = {
|
||||
"x-litellm-call-id": call_id,
|
||||
"x-litellm-model-id": model_id,
|
||||
"x-litellm-cache-key": cache_key,
|
||||
"x-litellm-model-api-base": api_base,
|
||||
"x-litellm-version": version,
|
||||
"x-litellm-model-region": model_region,
|
||||
"x-litellm-response-cost": str(response_cost),
|
||||
"x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
|
||||
"x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
|
||||
"x-litellm-key-max-budget": str(user_api_key_dict.max_budget),
|
||||
"x-litellm-key-spend": str(user_api_key_dict.spend),
|
||||
"x-litellm-response-duration-ms": str(
|
||||
hidden_params.get("_response_ms", None)
|
||||
),
|
||||
"x-litellm-overhead-duration-ms": str(
|
||||
hidden_params.get("litellm_overhead_time_ms", None)
|
||||
),
|
||||
"x-litellm-fastest_response_batch_completion": (
|
||||
str(fastest_response_batch_completion)
|
||||
if fastest_response_batch_completion is not None
|
||||
else None
|
||||
),
|
||||
"x-litellm-timeout": str(timeout) if timeout is not None else None,
|
||||
**{k: str(v) for k, v in kwargs.items()},
|
||||
}
|
||||
if request_data:
|
||||
remaining_tokens_header = (
|
||||
get_remaining_tokens_and_requests_from_request_data(request_data)
|
||||
)
|
||||
headers.update(remaining_tokens_header)
|
||||
|
||||
logging_caching_headers = get_logging_caching_headers(request_data)
|
||||
if logging_caching_headers:
|
||||
headers.update(logging_caching_headers)
|
||||
|
||||
try:
|
||||
return {
|
||||
key: str(value)
|
||||
for key, value in headers.items()
|
||||
if value not in exclude_values
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Error setting custom headers: {e}")
|
||||
return {}
|
||||
|
||||
async def base_process_llm_request(
|
||||
self,
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
route_type: Literal["acompletion", "aresponses"],
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
general_settings: dict,
|
||||
proxy_config: ProxyConfig,
|
||||
select_data_generator: Callable,
|
||||
llm_router: Optional[Router] = None,
|
||||
model: Optional[str] = None,
|
||||
user_model: Optional[str] = None,
|
||||
user_temperature: Optional[float] = None,
|
||||
user_request_timeout: Optional[float] = None,
|
||||
user_max_tokens: Optional[int] = None,
|
||||
user_api_base: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Common request processing logic for both chat completions and responses API endpoints
|
||||
"""
|
||||
verbose_proxy_logger.debug(
|
||||
"Request received by LiteLLM:\n{}".format(json.dumps(self.data, indent=4)),
|
||||
)
|
||||
|
||||
self.data = await add_litellm_data_to_request(
|
||||
data=self.data,
|
||||
request=request,
|
||||
general_settings=general_settings,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
version=version,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
self.data["model"] = (
|
||||
general_settings.get("completion_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
or model # for azure deployments
|
||||
or self.data.get("model", None) # default passed in http request
|
||||
)
|
||||
|
||||
# override with user settings, these are params passed via cli
|
||||
if user_temperature:
|
||||
self.data["temperature"] = user_temperature
|
||||
if user_request_timeout:
|
||||
self.data["request_timeout"] = user_request_timeout
|
||||
if user_max_tokens:
|
||||
self.data["max_tokens"] = user_max_tokens
|
||||
if user_api_base:
|
||||
self.data["api_base"] = user_api_base
|
||||
|
||||
### MODEL ALIAS MAPPING ###
|
||||
# check if model name in model alias map
|
||||
# get the actual model name
|
||||
if (
|
||||
isinstance(self.data["model"], str)
|
||||
and self.data["model"] in litellm.model_alias_map
|
||||
):
|
||||
self.data["model"] = litellm.model_alias_map[self.data["model"]]
|
||||
|
||||
### CALL HOOKS ### - modify/reject incoming data before calling the model
|
||||
self.data = await proxy_logging_obj.pre_call_hook( # type: ignore
|
||||
user_api_key_dict=user_api_key_dict, data=self.data, call_type="completion"
|
||||
)
|
||||
|
||||
## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
|
||||
## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse.
|
||||
self.data["litellm_call_id"] = request.headers.get(
|
||||
"x-litellm-call-id", str(uuid.uuid4())
|
||||
)
|
||||
logging_obj, self.data = litellm.utils.function_setup(
|
||||
original_function=route_type,
|
||||
rules_obj=litellm.utils.Rules(),
|
||||
start_time=datetime.now(),
|
||||
**self.data,
|
||||
)
|
||||
|
||||
self.data["litellm_logging_obj"] = logging_obj
|
||||
|
||||
tasks = []
|
||||
tasks.append(
|
||||
proxy_logging_obj.during_call_hook(
|
||||
data=self.data,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_type=ProxyBaseLLMRequestProcessing._get_pre_call_type(
|
||||
route_type=route_type
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
### ROUTE THE REQUEST ###
|
||||
# Do not change this - it should be a constant time fetch - ALWAYS
|
||||
llm_call = await route_request(
|
||||
data=self.data,
|
||||
route_type=route_type,
|
||||
llm_router=llm_router,
|
||||
user_model=user_model,
|
||||
)
|
||||
tasks.append(llm_call)
|
||||
|
||||
# wait for call to end
|
||||
llm_responses = asyncio.gather(
|
||||
*tasks
|
||||
) # run the moderation check in parallel to the actual llm api call
|
||||
|
||||
responses = await llm_responses
|
||||
|
||||
response = responses[1]
|
||||
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
response_cost = hidden_params.get("response_cost", None) or ""
|
||||
fastest_response_batch_completion = hidden_params.get(
|
||||
"fastest_response_batch_completion", None
|
||||
)
|
||||
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
|
||||
|
||||
# Post Call Processing
|
||||
if llm_router is not None:
|
||||
self.data["deployment"] = llm_router.get_deployment(model_id=model_id)
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.update_request_status(
|
||||
litellm_call_id=self.data.get("litellm_call_id", ""), status="success"
|
||||
)
|
||||
)
|
||||
if (
|
||||
"stream" in self.data and self.data["stream"] is True
|
||||
): # use generate_responses to stream responses
|
||||
custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_id=logging_obj.litellm_call_id,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
response_cost=response_cost,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
fastest_response_batch_completion=fastest_response_batch_completion,
|
||||
request_data=self.data,
|
||||
hidden_params=hidden_params,
|
||||
**additional_headers,
|
||||
)
|
||||
selected_data_generator = select_data_generator(
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=self.data,
|
||||
)
|
||||
return StreamingResponse(
|
||||
selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
headers=custom_headers,
|
||||
)
|
||||
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
data=self.data, user_api_key_dict=user_api_key_dict, response=response
|
||||
)
|
||||
|
||||
hidden_params = (
|
||||
getattr(response, "_hidden_params", {}) or {}
|
||||
) # get any updated response headers
|
||||
additional_headers = hidden_params.get("additional_headers", {}) or {}
|
||||
|
||||
fastapi_response.headers.update(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_id=logging_obj.litellm_call_id,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
response_cost=response_cost,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
fastest_response_batch_completion=fastest_response_batch_completion,
|
||||
request_data=self.data,
|
||||
hidden_params=hidden_params,
|
||||
**additional_headers,
|
||||
)
|
||||
)
|
||||
await check_response_size_is_safe(response=response)
|
||||
|
||||
return response
|
||||
|
||||
async def _handle_llm_api_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
version: Optional[str] = None,
|
||||
):
|
||||
"""Raises ProxyException (OpenAI API compatible) if an exception is raised"""
|
||||
verbose_proxy_logger.exception(
|
||||
f"litellm.proxy.proxy_server._handle_llm_api_exception(): Exception occured - {str(e)}"
|
||||
)
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
request_data=self.data,
|
||||
)
|
||||
litellm_debug_info = getattr(e, "litellm_debug_info", "")
|
||||
verbose_proxy_logger.debug(
|
||||
"\033[1;31mAn error occurred: %s %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
|
||||
e,
|
||||
litellm_debug_info,
|
||||
)
|
||||
|
||||
timeout = getattr(
|
||||
e, "timeout", None
|
||||
) # returns the timeout set by the wrapper. Used for testing if model-specific timeout are set correctly
|
||||
_litellm_logging_obj: Optional[LiteLLMLoggingObj] = self.data.get(
|
||||
"litellm_logging_obj", None
|
||||
)
|
||||
custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_id=(
|
||||
_litellm_logging_obj.litellm_call_id if _litellm_logging_obj else None
|
||||
),
|
||||
version=version,
|
||||
response_cost=0,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
request_data=self.data,
|
||||
timeout=timeout,
|
||||
)
|
||||
headers = getattr(e, "headers", {}) or {}
|
||||
headers.update(custom_headers)
|
||||
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", str(e)),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
headers=headers,
|
||||
)
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
openai_code=getattr(e, "code", None),
|
||||
code=getattr(e, "status_code", 500),
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_pre_call_type(
|
||||
route_type: Literal["acompletion", "aresponses"]
|
||||
) -> Literal["completion", "responses"]:
|
||||
if route_type == "acompletion":
|
||||
return "completion"
|
||||
elif route_type == "aresponses":
|
||||
return "responses"
|
|
@ -61,6 +61,7 @@ class MyCustomHandler(
|
|||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
pass
|
||||
|
|
|
@ -66,6 +66,7 @@ class myCustomGuardrail(CustomGuardrail):
|
|||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
"""
|
||||
|
|
|
@ -15,6 +15,7 @@ import litellm
|
|||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.utils import handle_exception_on_proxy
|
||||
|
||||
router = APIRouter()
|
||||
|
@ -97,7 +98,6 @@ async def create_fine_tuning_job(
|
|||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
premium_user,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
|
@ -151,7 +151,7 @@ async def create_fine_tuning_job(
|
|||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
@ -205,7 +205,6 @@ async def retrieve_fine_tuning_job(
|
|||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
premium_user,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
|
@ -248,7 +247,7 @@ async def retrieve_fine_tuning_job(
|
|||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
@ -305,7 +304,6 @@ async def list_fine_tuning_jobs(
|
|||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
premium_user,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
|
@ -349,7 +347,7 @@ async def list_fine_tuning_jobs(
|
|||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
@ -404,7 +402,6 @@ async def cancel_fine_tuning_job(
|
|||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
premium_user,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
|
@ -451,7 +448,7 @@ async def cancel_fine_tuning_job(
|
|||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
|
|
@ -25,8 +25,12 @@ class AimGuardrailMissingSecrets(Exception):
|
|||
|
||||
|
||||
class AimGuardrail(CustomGuardrail):
|
||||
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs):
|
||||
self.async_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.GuardrailCallback)
|
||||
def __init__(
|
||||
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
|
||||
):
|
||||
self.async_handler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.GuardrailCallback
|
||||
)
|
||||
self.api_key = api_key or os.environ.get("AIM_API_KEY")
|
||||
if not self.api_key:
|
||||
msg = (
|
||||
|
@ -34,7 +38,9 @@ class AimGuardrail(CustomGuardrail):
|
|||
"pass it as a parameter to the guardrail in the config file"
|
||||
)
|
||||
raise AimGuardrailMissingSecrets(msg)
|
||||
self.api_base = api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
|
||||
self.api_base = (
|
||||
api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def async_pre_call_hook(
|
||||
|
@ -68,6 +74,7 @@ class AimGuardrail(CustomGuardrail):
|
|||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
) -> Union[Exception, str, dict, None]:
|
||||
verbose_proxy_logger.debug("Inside AIM Moderation Hook")
|
||||
|
@ -77,9 +84,10 @@ class AimGuardrail(CustomGuardrail):
|
|||
|
||||
async def call_aim_guardrail(self, data: dict, hook: str) -> None:
|
||||
user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "x-aim-litellm-hook": hook} | (
|
||||
{"x-aim-user-email": user_email} if user_email else {}
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"x-aim-litellm-hook": hook,
|
||||
} | ({"x-aim-user-email": user_email} if user_email else {})
|
||||
response = await self.async_handler.post(
|
||||
f"{self.api_base}/detect/openai",
|
||||
headers=headers,
|
||||
|
|
|
@ -178,7 +178,7 @@ class AporiaGuardrail(CustomGuardrail):
|
|||
pass
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_moderation_hook( ### 👈 KEY CHANGE ###
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
|
@ -188,6 +188,7 @@ class AporiaGuardrail(CustomGuardrail):
|
|||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
|
|
|
@ -240,7 +240,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
|||
)
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_moderation_hook( ### 👈 KEY CHANGE ###
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
|
@ -250,6 +250,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
|||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
|
|
|
@ -70,6 +70,7 @@ class myCustomGuardrail(CustomGuardrail):
|
|||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
"""
|
||||
|
|
|
@ -134,6 +134,7 @@ class lakeraAI_Moderation(CustomGuardrail):
|
|||
"audio_transcription",
|
||||
"pass_through_endpoint",
|
||||
"rerank",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
if (
|
||||
|
@ -335,7 +336,7 @@ class lakeraAI_Moderation(CustomGuardrail):
|
|||
)
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_moderation_hook( ### 👈 KEY CHANGE ###
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
|
@ -345,6 +346,7 @@ class lakeraAI_Moderation(CustomGuardrail):
|
|||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
if self.event_hook is None:
|
||||
|
|
|
@@ -62,10 +62,18 @@ def _get_metadata_variable_name(request: Request) -> str:
     """
     if RouteChecks._is_assistants_api_request(request):
         return "litellm_metadata"
-    if "batches" in request.url.path:
-        return "litellm_metadata"
-    if "/v1/messages" in request.url.path:
-        # anthropic API has a field called metadata
+
+    LITELLM_METADATA_ROUTES = [
+        "batches",
+        "/v1/messages",
+        "responses",
+    ]
+    if any(
+        [
+            litellm_metadata_route in request.url.path
+            for litellm_metadata_route in LITELLM_METADATA_ROUTES
+        ]
+    ):
         return "litellm_metadata"
     else:
         return "metadata"
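The effect of the change above: any request whose path contains one of the `LITELLM_METADATA_ROUTES` entries (now including `responses`) stores LiteLLM-internal metadata under `litellm_metadata`, keeping it separate from a caller-supplied `metadata` field. A standalone sketch of that routing decision, with the assistants-API special case omitted:

```python
# Sketch of the selection logic above, isolated from the proxy internals.
# The RouteChecks._is_assistants_api_request branch is not reproduced here;
# only the path-based routing is shown.
LITELLM_METADATA_ROUTES = ["batches", "/v1/messages", "responses"]


def metadata_field_for(path: str) -> str:
    if any(route in path for route in LITELLM_METADATA_ROUTES):
        return "litellm_metadata"
    return "metadata"


assert metadata_field_for("/v1/responses") == "litellm_metadata"
assert metadata_field_for("/v1/chat/completions") == "metadata"
```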
|
@ -27,6 +27,7 @@ from litellm import CreateFileRequest, get_secret_str
|
|||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||
get_custom_llm_provider_from_request_body,
|
||||
)
|
||||
|
@ -145,7 +146,6 @@ async def create_file(
|
|||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
llm_router,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
|
@ -234,7 +234,7 @@ async def create_file(
|
|||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
@ -309,7 +309,6 @@ async def get_file_content(
|
|||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
version,
|
||||
|
@ -351,7 +350,7 @@ async def get_file_content(
|
|||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
@ -437,7 +436,6 @@ async def get_file(
|
|||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
version,
|
||||
|
@ -477,7 +475,7 @@ async def get_file(
|
|||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
@ -554,7 +552,6 @@ async def delete_file(
|
|||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
version,
|
||||
|
@ -595,7 +592,7 @@ async def delete_file(
|
|||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
@ -671,7 +668,6 @@ async def list_files(
|
|||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
version,
|
||||
|
@ -712,7 +708,7 @@ async def list_files(
|
|||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
|
|
@ -23,6 +23,7 @@ from litellm.proxy._types import (
|
|||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
|
@ -106,7 +107,6 @@ async def chat_completion_pass_through_endpoint( # noqa: PLR0915
|
|||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
llm_router,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
|
@ -231,7 +231,7 @@ async def chat_completion_pass_through_endpoint( # noqa: PLR0915
|
|||
verbose_proxy_logger.debug("final response: %s", response)
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
|
|
|
@@ -1,10 +1,6 @@
 model_list:
-  - model_name: thinking-us.anthropic.claude-3-7-sonnet-20250219-v1:0
+  - model_name: gpt-4o
     litellm_params:
-      model: bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0
-      thinking: {"type": "enabled", "budget_tokens": 1024}
-      max_tokens: 1080
-      merge_reasoning_content_in_choices: true
+      model: gpt-4o
|
@ -139,12 +139,9 @@ from litellm.proxy.batches_endpoints.endpoints import router as batches_router
|
|||
|
||||
## Import All Misc routes here ##
|
||||
from litellm.proxy.caching_routes import router as caching_router
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.common_utils.admin_ui_utils import html_form
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
get_logging_caching_headers,
|
||||
get_remaining_tokens_and_requests_from_request_data,
|
||||
initialize_callbacks_on_proxy,
|
||||
)
|
||||
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
|
||||
from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
|
||||
from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
|
@ -236,6 +233,7 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
|||
router as pass_through_router,
|
||||
)
|
||||
from litellm.proxy.rerank_endpoints.endpoints import router as rerank_router
|
||||
from litellm.proxy.response_api_endpoints.endpoints import router as response_router
|
||||
from litellm.proxy.route_llm_request import route_request
|
||||
from litellm.proxy.spend_tracking.spend_management_endpoints import (
|
||||
router as spend_management_router,
|
||||
|
@ -783,69 +781,6 @@ db_writer_client: Optional[AsyncHTTPHandler] = None
|
|||
### logger ###
|
||||
|
||||
|
||||
def get_custom_headers(
|
||||
*,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_id: Optional[str] = None,
|
||||
model_id: Optional[str] = None,
|
||||
cache_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
model_region: Optional[str] = None,
|
||||
response_cost: Optional[Union[float, str]] = None,
|
||||
hidden_params: Optional[dict] = None,
|
||||
fastest_response_batch_completion: Optional[bool] = None,
|
||||
request_data: Optional[dict] = {},
|
||||
timeout: Optional[Union[float, int, httpx.Timeout]] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
exclude_values = {"", None, "None"}
|
||||
hidden_params = hidden_params or {}
|
||||
headers = {
|
||||
"x-litellm-call-id": call_id,
|
||||
"x-litellm-model-id": model_id,
|
||||
"x-litellm-cache-key": cache_key,
|
||||
"x-litellm-model-api-base": api_base,
|
||||
"x-litellm-version": version,
|
||||
"x-litellm-model-region": model_region,
|
||||
"x-litellm-response-cost": str(response_cost),
|
||||
"x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
|
||||
"x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
|
||||
"x-litellm-key-max-budget": str(user_api_key_dict.max_budget),
|
||||
"x-litellm-key-spend": str(user_api_key_dict.spend),
|
||||
"x-litellm-response-duration-ms": str(hidden_params.get("_response_ms", None)),
|
||||
"x-litellm-overhead-duration-ms": str(
|
||||
hidden_params.get("litellm_overhead_time_ms", None)
|
||||
),
|
||||
"x-litellm-fastest_response_batch_completion": (
|
||||
str(fastest_response_batch_completion)
|
||||
if fastest_response_batch_completion is not None
|
||||
else None
|
||||
),
|
||||
"x-litellm-timeout": str(timeout) if timeout is not None else None,
|
||||
**{k: str(v) for k, v in kwargs.items()},
|
||||
}
|
||||
if request_data:
|
||||
remaining_tokens_header = get_remaining_tokens_and_requests_from_request_data(
|
||||
request_data
|
||||
)
|
||||
headers.update(remaining_tokens_header)
|
||||
|
||||
logging_caching_headers = get_logging_caching_headers(request_data)
|
||||
if logging_caching_headers:
|
||||
headers.update(logging_caching_headers)
|
||||
|
||||
try:
|
||||
return {
|
||||
key: str(value)
|
||||
for key, value in headers.items()
|
||||
if value not in exclude_values
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Error setting custom headers: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
async def check_request_disconnection(request: Request, llm_api_call_task):
|
||||
"""
|
||||
Asynchronously checks if the request is disconnected at regular intervals.
|
||||
|
@ -3518,169 +3453,28 @@ async def chat_completion( # noqa: PLR0915
|
|||
|
||||
"""
|
||||
global general_settings, user_debug, proxy_logging_obj, llm_model_list
|
||||
|
||||
data = {}
|
||||
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
||||
data = await _read_request_body(request=request)
|
||||
base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
|
||||
try:
|
||||
data = await _read_request_body(request=request)
|
||||
verbose_proxy_logger.debug(
|
||||
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
|
||||
)
|
||||
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
return await base_llm_response_processor.base_process_llm_request(
|
||||
request=request,
|
||||
general_settings=general_settings,
|
||||
fastapi_response=fastapi_response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
version=version,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
data["model"] = (
|
||||
general_settings.get("completion_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
or model # for azure deployments
|
||||
or data.get("model", None) # default passed in http request
|
||||
)
|
||||
|
||||
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
||||
# override with user settings, these are params passed via cli
|
||||
if user_temperature:
|
||||
data["temperature"] = user_temperature
|
||||
if user_request_timeout:
|
||||
data["request_timeout"] = user_request_timeout
|
||||
if user_max_tokens:
|
||||
data["max_tokens"] = user_max_tokens
|
||||
if user_api_base:
|
||||
data["api_base"] = user_api_base
|
||||
|
||||
### MODEL ALIAS MAPPING ###
|
||||
# check if model name in model alias map
|
||||
# get the actual model name
|
||||
if isinstance(data["model"], str) and data["model"] in litellm.model_alias_map:
|
||||
data["model"] = litellm.model_alias_map[data["model"]]
|
||||
|
||||
### CALL HOOKS ### - modify/reject incoming data before calling the model
|
||||
data = await proxy_logging_obj.pre_call_hook( # type: ignore
|
||||
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
|
||||
)
|
||||
|
||||
## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
|
||||
## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse.
|
||||
data["litellm_call_id"] = request.headers.get(
|
||||
"x-litellm-call-id", str(uuid.uuid4())
|
||||
)
|
||||
logging_obj, data = litellm.utils.function_setup(
|
||||
original_function="acompletion",
|
||||
rules_obj=litellm.utils.Rules(),
|
||||
start_time=datetime.now(),
|
||||
**data,
|
||||
)
|
||||
|
||||
data["litellm_logging_obj"] = logging_obj
|
||||
|
||||
tasks = []
|
||||
tasks.append(
|
||||
proxy_logging_obj.during_call_hook(
|
||||
data=data,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_type="completion",
|
||||
)
|
||||
)
|
||||
|
||||
### ROUTE THE REQUEST ###
|
||||
# Do not change this - it should be a constant time fetch - ALWAYS
|
||||
llm_call = await route_request(
|
||||
data=data,
|
||||
route_type="acompletion",
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
llm_router=llm_router,
|
||||
general_settings=general_settings,
|
||||
proxy_config=proxy_config,
|
||||
select_data_generator=select_data_generator,
|
||||
model=model,
|
||||
user_model=user_model,
|
||||
user_temperature=user_temperature,
|
||||
user_request_timeout=user_request_timeout,
|
||||
user_max_tokens=user_max_tokens,
|
||||
user_api_base=user_api_base,
|
||||
version=version,
|
||||
)
|
||||
tasks.append(llm_call)
|
||||
|
||||
# wait for call to end
|
||||
llm_responses = asyncio.gather(
|
||||
*tasks
|
||||
) # run the moderation check in parallel to the actual llm api call
|
||||
|
||||
responses = await llm_responses
|
||||
|
||||
response = responses[1]
|
||||
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
response_cost = hidden_params.get("response_cost", None) or ""
|
||||
fastest_response_batch_completion = hidden_params.get(
|
||||
"fastest_response_batch_completion", None
|
||||
)
|
||||
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
|
||||
|
||||
# Post Call Processing
|
||||
if llm_router is not None:
|
||||
data["deployment"] = llm_router.get_deployment(model_id=model_id)
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.update_request_status(
|
||||
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
||||
)
|
||||
)
|
||||
if (
|
||||
"stream" in data and data["stream"] is True
|
||||
): # use generate_responses to stream responses
|
||||
custom_headers = get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_id=logging_obj.litellm_call_id,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
response_cost=response_cost,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
fastest_response_batch_completion=fastest_response_batch_completion,
|
||||
request_data=data,
|
||||
hidden_params=hidden_params,
|
||||
**additional_headers,
|
||||
)
|
||||
selected_data_generator = select_data_generator(
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=data,
|
||||
)
|
||||
return StreamingResponse(
|
||||
selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
headers=custom_headers,
|
||||
)
|
||||
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
data=data, user_api_key_dict=user_api_key_dict, response=response
|
||||
)
|
||||
|
||||
hidden_params = (
|
||||
getattr(response, "_hidden_params", {}) or {}
|
||||
) # get any updated response headers
|
||||
additional_headers = hidden_params.get("additional_headers", {}) or {}
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_id=logging_obj.litellm_call_id,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
response_cost=response_cost,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
fastest_response_batch_completion=fastest_response_batch_completion,
|
||||
request_data=data,
|
||||
hidden_params=hidden_params,
|
||||
**additional_headers,
|
||||
)
|
||||
)
|
||||
await check_response_size_is_safe(response=response)
|
||||
|
||||
return response
|
||||
except RejectedRequestError as e:
|
||||
_data = e.request_data
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
|
@ -3715,55 +3509,10 @@ async def chat_completion( # noqa: PLR0915
|
|||
_chat_response.usage = _usage # type: ignore
|
||||
return _chat_response
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"litellm.proxy.proxy_server.chat_completion(): Exception occured - {str(e)}"
|
||||
)
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
litellm_debug_info = getattr(e, "litellm_debug_info", "")
|
||||
verbose_proxy_logger.debug(
|
||||
"\033[1;31mAn error occurred: %s %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
|
||||
e,
|
||||
litellm_debug_info,
|
||||
)
|
||||
|
||||
timeout = getattr(
|
||||
e, "timeout", None
|
||||
) # returns the timeout set by the wrapper. Used for testing if model-specific timeout are set correctly
|
||||
_litellm_logging_obj: Optional[LiteLLMLoggingObj] = data.get(
|
||||
"litellm_logging_obj", None
|
||||
)
|
||||
custom_headers = get_custom_headers(
|
||||
raise await base_llm_response_processor._handle_llm_api_exception(
|
||||
e=e,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_id=(
|
||||
_litellm_logging_obj.litellm_call_id if _litellm_logging_obj else None
|
||||
),
|
||||
version=version,
|
||||
response_cost=0,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
request_data=data,
|
||||
timeout=timeout,
|
||||
)
|
||||
headers = getattr(e, "headers", {}) or {}
|
||||
headers.update(custom_headers)
|
||||
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", str(e)),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
headers=headers,
|
||||
)
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
openai_code=getattr(e, "code", None),
|
||||
code=getattr(e, "status_code", 500),
|
||||
headers=headers,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
|
||||
|
@ -3880,7 +3629,7 @@ async def completion( # noqa: PLR0915
|
|||
if (
|
||||
"stream" in data and data["stream"] is True
|
||||
): # use generate_responses to stream responses
|
||||
custom_headers = get_custom_headers(
|
||||
custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_id=litellm_call_id,
|
||||
model_id=model_id,
|
||||
|
@ -3908,7 +3657,7 @@ async def completion( # noqa: PLR0915
|
|||
)
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
ProxyBaseLLMRequestProcessing.get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_id=litellm_call_id,
|
||||
model_id=model_id,
|
||||
|
@ -4139,7 +3888,7 @@ async def embeddings( # noqa: PLR0915
|
|||
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
|
||||
|
||||
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

@ -4267,7 +4016,7 @@ async def image_generation(
litellm_call_id = hidden_params.get("litellm_call_id", None) or ""

fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

@ -4388,7 +4137,7 @@ async def audio_speech(
async for chunk in _generator:
yield chunk

custom_headers = get_custom_headers(
custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

@ -4529,7 +4278,7 @@ async def audio_transcriptions(
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}

fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

@ -4681,7 +4430,7 @@ async def get_assistants(
api_base = hidden_params.get("api_base", None) or ""

fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

@ -4780,7 +4529,7 @@ async def create_assistant(
api_base = hidden_params.get("api_base", None) or ""

fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

@ -4877,7 +4626,7 @@ async def delete_assistant(
api_base = hidden_params.get("api_base", None) or ""

fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

@ -4974,7 +4723,7 @@ async def create_threads(
api_base = hidden_params.get("api_base", None) or ""

fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

@ -5070,7 +4819,7 @@ async def get_thread(
api_base = hidden_params.get("api_base", None) or ""

fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

@ -5169,7 +4918,7 @@ async def add_messages(
api_base = hidden_params.get("api_base", None) or ""

fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

@ -5264,7 +5013,7 @@ async def get_messages(
api_base = hidden_params.get("api_base", None) or ""

fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

@ -5373,7 +5122,7 @@ async def run_thread(
api_base = hidden_params.get("api_base", None) or ""

fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

@ -5496,7 +5245,7 @@ async def moderations(
api_base = hidden_params.get("api_base", None) or ""

fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,

@ -8640,6 +8389,7 @@ async def get_routes():

app.include_router(router)
app.include_router(response_router)
app.include_router(batches_router)
app.include_router(rerank_router)
app.include_router(fine_tuning_router)
@ -7,10 +7,12 @@ from fastapi.responses import ORJSONResponse
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing

router = APIRouter()
import asyncio

@router.post(
"/v2/rerank",
dependencies=[Depends(user_api_key_auth)],

@ -37,7 +39,6 @@ async def rerank(
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
llm_router,
proxy_config,
proxy_logging_obj,

@ -89,7 +90,7 @@ async def rerank(
api_base = hidden_params.get("api_base", None) or ""
additional_headers = hidden_params.get("additional_headers", None) or {}
fastapi_response.headers.update(
get_custom_headers(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
80
litellm/proxy/response_api_endpoints/endpoints.py
Normal file
@ -0,0 +1,80 @@
from fastapi import APIRouter, Depends, Request, Response

from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth, user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing

router = APIRouter()


@router.post(
"/v1/responses",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
@router.post(
"/responses",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
async def responses_api(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Follows the OpenAI Responses API spec: https://platform.openai.com/docs/api-reference/responses

```bash
curl -X POST http://localhost:4000/v1/responses \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gpt-4o",
"input": "Tell me about AI"
}'
```
"""
from litellm.proxy.proxy_server import (
_read_request_body,
general_settings,
llm_router,
proxy_config,
proxy_logging_obj,
select_data_generator,
user_api_base,
user_max_tokens,
user_model,
user_request_timeout,
user_temperature,
version,
)

data = await _read_request_body(request=request)
processor = ProxyBaseLLMRequestProcessing(data=data)
try:
return await processor.base_process_llm_request(
request=request,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
route_type="aresponses",
proxy_logging_obj=proxy_logging_obj,
llm_router=llm_router,
general_settings=general_settings,
proxy_config=proxy_config,
select_data_generator=select_data_generator,
model=None,
user_model=user_model,
user_temperature=user_temperature,
user_request_timeout=user_request_timeout,
user_max_tokens=user_max_tokens,
user_api_base=user_api_base,
version=version,
)
except Exception as e:
raise await processor._handle_llm_api_exception(
e=e,
user_api_key_dict=user_api_key_dict,
proxy_logging_obj=proxy_logging_obj,
version=version,
)
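Since the new route just reads the raw request body and hands it to `ProxyBaseLLMRequestProcessing` with `route_type="aresponses"`, any OpenAI-compatible client can call it. A minimal usage sketch, assuming a proxy running at `http://localhost:4000`, a virtual key `sk-1234`, and a recent OpenAI Python SDK (all illustrative values, not part of this diff):

```python
# Hypothetical usage sketch: call the proxy's /v1/responses route via the OpenAI SDK.
# "http://localhost:4000" and "sk-1234" are placeholder values.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000/v1", api_key="sk-1234")

response = client.responses.create(
    model="gpt-4o",
    input="Tell me about AI",
)
print(response.output_text)
```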
@ -21,6 +21,7 @@ ROUTE_ENDPOINT_MAPPING = {
"atranscription": "/audio/transcriptions",
"amoderation": "/moderations",
"arerank": "/rerank",
"aresponses": "/responses",
}

@ -45,6 +46,7 @@ async def route_request(
"atranscription",
"amoderation",
"arerank",
"aresponses",
"_arealtime", # private function for realtime API
],
):
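The new `"aresponses": "/responses"` mapping and the widened `route_type` literal are what let the proxy hand `/responses` traffic to the router. A standalone sketch of the dispatch idea, using simplified stand-in names rather than the actual `route_request` implementation:

```python
# Illustrative sketch (not the real route_llm_request code): the route_type
# string is looked up as an attribute on the router, so the proxy can forward
# /responses traffic to llm_router.aresponses.
import asyncio


class FakeRouter:
    async def aresponses(self, **kwargs):
        return {"called_with": kwargs}


async def route_request(data: dict, llm_router, route_type: str):
    handler = getattr(llm_router, route_type)  # e.g. llm_router.aresponses
    return await handler(**data)


print(asyncio.run(route_request({"model": "gpt-4o", "input": "hi"}, FakeRouter(), "aresponses")))
```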
@ -537,6 +537,7 @@ class ProxyLogging:
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal[
"completion",
"responses",
"embeddings",
"image_generation",
"moderation",
@ -581,13 +581,7 @@ class Router:
self._initialize_alerting()

self.initialize_assistants_endpoint()

self.amoderation = self.factory_function(
litellm.amoderation, call_type="moderation"
)
self.aanthropic_messages = self.factory_function(
litellm.anthropic_messages, call_type="anthropic_messages"
)
self.initialize_router_endpoints()

def discard(self):
"""

@ -653,6 +647,18 @@ class Router:
self.aget_messages = self.factory_function(litellm.aget_messages)
self.arun_thread = self.factory_function(litellm.arun_thread)

def initialize_router_endpoints(self):
self.amoderation = self.factory_function(
litellm.amoderation, call_type="moderation"
)
self.aanthropic_messages = self.factory_function(
litellm.anthropic_messages, call_type="anthropic_messages"
)
self.aresponses = self.factory_function(
litellm.aresponses, call_type="aresponses"
)
self.responses = self.factory_function(litellm.responses, call_type="responses")

def routing_strategy_init(
self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
):

@ -1080,17 +1086,22 @@ class Router:
kwargs.setdefault("litellm_trace_id", str(uuid.uuid4()))
kwargs.setdefault("metadata", {}).update({"model_group": model})

def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
def _update_kwargs_with_default_litellm_params(
self, kwargs: dict, metadata_variable_name: Optional[str] = "metadata"
) -> None:
"""
Adds default litellm params to kwargs, if set.
"""
self.default_litellm_params[metadata_variable_name] = (
self.default_litellm_params.pop("metadata", {})
)
for k, v in self.default_litellm_params.items():
if (
k not in kwargs and v is not None
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
elif k == metadata_variable_name:
kwargs[metadata_variable_name].update(v)

def _handle_clientside_credential(
self, deployment: dict, kwargs: dict

@ -1121,7 +1132,12 @@ class Router:
) # add new deployment to router
return deployment_pydantic_obj

def _update_kwargs_with_deployment(self, deployment: dict, kwargs: dict) -> None:
def _update_kwargs_with_deployment(
self,
deployment: dict,
kwargs: dict,
function_name: Optional[str] = None,
) -> None:
"""
2 jobs:
- Adds selected deployment, model_info and api_base to kwargs["metadata"] (used for logging)

@ -1138,7 +1154,10 @@ class Router:
deployment_model_name = deployment_pydantic_obj.litellm_params.model
deployment_api_base = deployment_pydantic_obj.litellm_params.api_base

kwargs.setdefault("metadata", {}).update(
metadata_variable_name = _get_router_metadata_variable_name(
function_name=function_name,
)
kwargs.setdefault(metadata_variable_name, {}).update(
{
"deployment": deployment_model_name,
"model_info": model_info,

@ -1151,7 +1170,9 @@ class Router:
kwargs=kwargs, data=deployment["litellm_params"]
)

self._update_kwargs_with_default_litellm_params(kwargs=kwargs)
self._update_kwargs_with_default_litellm_params(
kwargs=kwargs, metadata_variable_name=metadata_variable_name
)

def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
"""

@ -2396,22 +2417,18 @@ class Router:
messages=kwargs.get("messages", None),
specific_deployment=kwargs.pop("specific_deployment", None),
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
self._update_kwargs_with_deployment(
deployment=deployment, kwargs=kwargs, function_name="generic_api_call"
)

data = deployment["litellm_params"].copy()
model_name = data["model"]

model_client = self._get_async_openai_model_client(
deployment=deployment,
kwargs=kwargs,
)
self.total_calls[model_name] += 1

response = original_function(
**{
**data,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
)

@ -2453,6 +2470,61 @@ class Router:
self.fail_calls[model] += 1
raise e

def _generic_api_call_with_fallbacks(
self, model: str, original_function: Callable, **kwargs
):
"""
Make a generic LLM API call through the router, this allows you to use retries/fallbacks with litellm router
Args:
model: The model to use
original_function: The handler function to call (e.g., litellm.completion)
**kwargs: Additional arguments to pass to the handler function
Returns:
The response from the handler function
"""
handler_name = original_function.__name__
try:
verbose_router_logger.debug(
f"Inside _generic_api_call() - handler: {handler_name}, model: {model}; kwargs: {kwargs}"
)
deployment = self.get_available_deployment(
model=model,
messages=kwargs.get("messages", None),
specific_deployment=kwargs.pop("specific_deployment", None),
)
self._update_kwargs_with_deployment(
deployment=deployment, kwargs=kwargs, function_name="generic_api_call"
)

data = deployment["litellm_params"].copy()
model_name = data["model"]

self.total_calls[model_name] += 1

# Perform pre-call checks for routing strategy
self.routing_strategy_pre_call_checks(deployment=deployment)

response = original_function(
**{
**data,
"caching": self.cache_responses,
**kwargs,
}
)

self.success_calls[model_name] += 1
verbose_router_logger.info(
f"{handler_name}(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"{handler_name}(model={model})\033[31m Exception {str(e)}\033[0m"
)
if model is not None:
self.fail_calls[model] += 1
raise e

def embedding(
self,
model: str,

@ -2974,14 +3046,42 @@ class Router:
self,
original_function: Callable,
call_type: Literal[
"assistants", "moderation", "anthropic_messages"
"assistants",
"moderation",
"anthropic_messages",
"aresponses",
"responses",
] = "assistants",
):
async def new_function(
"""
Creates appropriate wrapper functions for different API call types.

Returns:
- A synchronous function for synchronous call types
- An asynchronous function for asynchronous call types
"""
# Handle synchronous call types
if call_type == "responses":

def sync_wrapper(
custom_llm_provider: Optional[
Literal["openai", "azure", "anthropic"]
] = None,
client: Optional[Any] = None,
**kwargs,
):
return self._generic_api_call_with_fallbacks(
original_function=original_function, **kwargs
)

return sync_wrapper

# Handle asynchronous call types
async def async_wrapper(
custom_llm_provider: Optional[
Literal["openai", "azure", "anthropic"]
] = None,
client: Optional["AsyncOpenAI"] = None,
client: Optional[Any] = None,
**kwargs,
):
if call_type == "assistants":

@ -2992,18 +3092,16 @@ class Router:
**kwargs,
)
elif call_type == "moderation":

return await self._pass_through_moderation_endpoint_factory( # type: ignore
original_function=original_function,
**kwargs,
return await self._pass_through_moderation_endpoint_factory(
original_function=original_function, **kwargs
)
elif call_type == "anthropic_messages":
return await self._ageneric_api_call_with_fallbacks( # type: ignore
elif call_type in ("anthropic_messages", "aresponses"):
return await self._ageneric_api_call_with_fallbacks(
original_function=original_function,
**kwargs,
)

return new_function
return async_wrapper

async def _pass_through_assistants_endpoint_factory(
self,
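The factory change above is the core wiring: the sync `responses` call type gets a plain wrapper that goes through `_generic_api_call_with_fallbacks`, while `aresponses` reuses the async `_ageneric_api_call_with_fallbacks` path. A standalone sketch of that sync/async factory pattern, with illustrative names rather than litellm's real internals:

```python
# Standalone sketch of the wrapper-factory pattern: one factory returns a
# plain function for sync call types ("responses") and a coroutine function
# for async call types ("aresponses", ...). Names are illustrative only.
import asyncio
from typing import Any, Callable


def factory_function(original_function: Callable, call_type: str = "assistants"):
    if call_type == "responses":  # sync call types get a plain wrapper
        def sync_wrapper(**kwargs: Any):
            return original_function(**kwargs)

        return sync_wrapper

    async def async_wrapper(**kwargs: Any):  # async call types get a coroutine
        return await original_function(**kwargs)

    return async_wrapper


def fake_responses(**kwargs):
    return {"echo": kwargs}


async def fake_aresponses(**kwargs):
    return {"echo": kwargs}


responses = factory_function(fake_responses, call_type="responses")
aresponses = factory_function(fake_aresponses, call_type="aresponses")

print(responses(model="gpt-4o", input="hi"))
print(asyncio.run(aresponses(model="gpt-4o", input="hi")))
```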
@ -56,7 +56,8 @@ def _get_router_metadata_variable_name(function_name) -> str:

For ALL other endpoints we call this "metadata
"""
if "batch" in function_name:
ROUTER_METHODS_USING_LITELLM_METADATA = set(["batch", "generic_api_call"])
if function_name in ROUTER_METHODS_USING_LITELLM_METADATA:
return "litellm_metadata"
else:
return "metadata"
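For reference, the updated helper's behavior can be restated as a self-contained snippet with a tiny usage check (same logic as the hunk above, shown only as a runnable illustration):

```python
# Router methods in the set write their routing metadata under
# "litellm_metadata"; everything else keeps using "metadata".
ROUTER_METHODS_USING_LITELLM_METADATA = set(["batch", "generic_api_call"])


def _get_router_metadata_variable_name(function_name) -> str:
    if function_name in ROUTER_METHODS_USING_LITELLM_METADATA:
        return "litellm_metadata"
    return "metadata"


assert _get_router_metadata_variable_name("generic_api_call") == "litellm_metadata"
assert _get_router_metadata_variable_name("acompletion") == "metadata"
```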
@ -742,6 +742,9 @@ class BaseLiteLLMOpenAIResponseObject(BaseModel):
def __contains__(self, key):
return key in self.__dict__

def items(self):
return self.__dict__.items()


class OutputTokensDetails(BaseLiteLLMOpenAIResponseObject):
reasoning_tokens: int
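A short standalone sketch (assuming Pydantic v2) of what the added dict-like helpers buy: membership tests and `items()` work on the response object as if it were a mapping.

```python
# Minimal sketch, assuming Pydantic v2: __contains__ and items() let callers
# treat the pydantic response object like a mapping.
from pydantic import BaseModel


class BaseLiteLLMOpenAIResponseObject(BaseModel):
    def __contains__(self, key):
        return key in self.__dict__

    def items(self):
        return self.__dict__.items()


class Demo(BaseLiteLLMOpenAIResponseObject):
    reasoning_tokens: int = 0


d = Demo(reasoning_tokens=3)
assert "reasoning_tokens" in d
assert dict(d.items()) == {"reasoning_tokens": 3}
```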
@ -230,6 +230,7 @@ def test_select_azure_base_url_called(setup_mocks):
"anthropic_messages",
"add_message",
"arun_thread_stream",
"aresponses",
]
],
)
@ -3,6 +3,7 @@ import sys
import pytest
import asyncio
from typing import Optional
from unittest.mock import patch, AsyncMock

sys.path.insert(0, os.path.abspath("../.."))
import litellm

@ -16,6 +17,7 @@ from litellm.types.llms.openai import (
ResponseAPIUsage,
IncompleteDetails,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler


def validate_responses_api_response(response, final_chunk: bool = False):

@ -503,3 +505,293 @@ async def test_openai_responses_api_streaming_validation(sync_mode):
assert not missing_events, f"Missing required event types: {missing_events}"

print(f"Successfully validated all event types: {event_types_seen}")


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_openai_responses_litellm_router(sync_mode):
"""
Test the OpenAI responses API with LiteLLM Router in both sync and async modes
"""
litellm._turn_on_debug()
router = litellm.Router(
model_list=[
{
"model_name": "gpt4o-special-alias",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
}
]
)

# Call the handler
if sync_mode:
response = router.responses(
model="gpt4o-special-alias",
input="Hello, can you tell me a short joke?",
max_output_tokens=100,
)
print("SYNC MODE RESPONSE=", response)
else:
response = await router.aresponses(
model="gpt4o-special-alias",
input="Hello, can you tell me a short joke?",
max_output_tokens=100,
)

print(
f"Router {'sync' if sync_mode else 'async'} response=",
json.dumps(response, indent=4, default=str),
)

# Use the helper function to validate the response
validate_responses_api_response(response, final_chunk=True)

return response


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_openai_responses_litellm_router_streaming(sync_mode):
"""
Test the OpenAI responses API with streaming through LiteLLM Router
"""
litellm._turn_on_debug()
router = litellm.Router(
model_list=[
{
"model_name": "gpt4o-special-alias",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
}
]
)

event_types_seen = set()

if sync_mode:
response = router.responses(
model="gpt4o-special-alias",
input="Tell me about artificial intelligence in 2 sentences.",
stream=True,
)
for event in response:
print(f"Validating event type: {event.type}")
validate_stream_event(event)
event_types_seen.add(event.type)
else:
response = await router.aresponses(
model="gpt4o-special-alias",
input="Tell me about artificial intelligence in 2 sentences.",
stream=True,
)
async for event in response:
print(f"Validating event type: {event.type}")
validate_stream_event(event)
event_types_seen.add(event.type)

# At minimum, we should see these core event types
required_events = {"response.created", "response.completed"}

missing_events = required_events - event_types_seen
assert not missing_events, f"Missing required event types: {missing_events}"

print(f"Successfully validated all event types: {event_types_seen}")


@pytest.mark.asyncio
async def test_openai_responses_litellm_router_no_metadata():
"""
Test that metadata is not passed through when using the Router for responses API
"""
mock_response = {
"id": "resp_123",
"object": "response",
"created_at": 1741476542,
"status": "completed",
"model": "gpt-4o",
"output": [
{
"type": "message",
"id": "msg_123",
"status": "completed",
"role": "assistant",
"content": [
{"type": "output_text", "text": "Hello world!", "annotations": []}
],
}
],
"parallel_tool_calls": True,
"usage": {
"input_tokens": 10,
"output_tokens": 20,
"total_tokens": 30,
"output_tokens_details": {"reasoning_tokens": 0},
},
"text": {"format": {"type": "text"}},
# Adding all required fields
"error": None,
"incomplete_details": None,
"instructions": None,
"metadata": {},
"temperature": 1.0,
"tool_choice": "auto",
"tools": [],
"top_p": 1.0,
"max_output_tokens": None,
"previous_response_id": None,
"reasoning": {"effort": None, "summary": None},
"truncation": "disabled",
"user": None,
}

class MockResponse:
def __init__(self, json_data, status_code):
self._json_data = json_data
self.status_code = status_code
self.text = str(json_data)

def json(self): # Changed from async to sync
return self._json_data

with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
new_callable=AsyncMock,
) as mock_post:
# Configure the mock to return our response
mock_post.return_value = MockResponse(mock_response, 200)

litellm._turn_on_debug()
router = litellm.Router(
model_list=[
{
"model_name": "gpt4o-special-alias",
"litellm_params": {
"model": "gpt-4o",
"api_key": "fake-key",
},
}
]
)

# Call the handler with metadata
await router.aresponses(
model="gpt4o-special-alias",
input="Hello, can you tell me a short joke?",
)

# Check the request body
request_body = mock_post.call_args.kwargs["data"]
print("Request body:", json.dumps(request_body, indent=4))

loaded_request_body = json.loads(request_body)
print("Loaded request body:", json.dumps(loaded_request_body, indent=4))

# Assert metadata is not in the request
assert (
loaded_request_body["metadata"] == None
), "metadata should not be in the request body"
mock_post.assert_called_once()


@pytest.mark.asyncio
async def test_openai_responses_litellm_router_with_metadata():
"""
Test that metadata is correctly passed through when explicitly provided to the Router for responses API
"""
test_metadata = {
"user_id": "123",
"conversation_id": "abc",
"custom_field": "test_value",
}

mock_response = {
"id": "resp_123",
"object": "response",
"created_at": 1741476542,
"status": "completed",
"model": "gpt-4o",
"output": [
{
"type": "message",
"id": "msg_123",
"status": "completed",
"role": "assistant",
"content": [
{"type": "output_text", "text": "Hello world!", "annotations": []}
],
}
],
"parallel_tool_calls": True,
"usage": {
"input_tokens": 10,
"output_tokens": 20,
"total_tokens": 30,
"output_tokens_details": {"reasoning_tokens": 0},
},
"text": {"format": {"type": "text"}},
"error": None,
"incomplete_details": None,
"instructions": None,
"metadata": test_metadata,  # Include the test metadata in response
"temperature": 1.0,
"tool_choice": "auto",
"tools": [],
"top_p": 1.0,
"max_output_tokens": None,
"previous_response_id": None,
"reasoning": {"effort": None, "summary": None},
"truncation": "disabled",
"user": None,
}

class MockResponse:
def __init__(self, json_data, status_code):
self._json_data = json_data
self.status_code = status_code
self.text = str(json_data)

def json(self):
return self._json_data

with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
new_callable=AsyncMock,
) as mock_post:
# Configure the mock to return our response
mock_post.return_value = MockResponse(mock_response, 200)

litellm._turn_on_debug()
router = litellm.Router(
model_list=[
{
"model_name": "gpt4o-special-alias",
"litellm_params": {
"model": "gpt-4o",
"api_key": "fake-key",
},
}
]
)

# Call the handler with metadata
await router.aresponses(
model="gpt4o-special-alias",
input="Hello, can you tell me a short joke?",
metadata=test_metadata,
)

# Check the request body
request_body = mock_post.call_args.kwargs["data"]
loaded_request_body = json.loads(request_body)
print("Request body:", json.dumps(loaded_request_body, indent=4))

# Assert metadata matches exactly what was passed
assert (
loaded_request_body["metadata"] == test_metadata
), "metadata in request body should match what was passed"
mock_post.assert_called_once()
@ -315,14 +315,20 @@ async def test_router_with_empty_choices(model_list):
assert response is not None


@pytest.mark.asyncio
async def test_ageneric_api_call_with_fallbacks_basic():
@pytest.mark.parametrize("sync_mode", [True, False])
def test_generic_api_call_with_fallbacks_basic(sync_mode):
"""
Test the _ageneric_api_call_with_fallbacks method with a basic successful call
Test both the sync and async versions of generic_api_call_with_fallbacks with a basic successful call
"""
# Create a mock function that will be passed to _ageneric_api_call_with_fallbacks
mock_function = AsyncMock()
mock_function.__name__ = "test_function"
# Create a mock function that will be passed to generic_api_call_with_fallbacks
if sync_mode:
from unittest.mock import Mock

mock_function = Mock()
mock_function.__name__ = "test_function"
else:
mock_function = AsyncMock()
mock_function.__name__ = "test_function"

# Create a mock response
mock_response = {

@ -347,13 +353,23 @@ async def test_ageneric_api_call_with_fallbacks_basic():
]
)

# Call the _ageneric_api_call_with_fallbacks method
response = await router._ageneric_api_call_with_fallbacks(
model="test-model-alias",
original_function=mock_function,
messages=[{"role": "user", "content": "Hello"}],
max_tokens=100,
)
# Call the appropriate generic_api_call_with_fallbacks method
if sync_mode:
response = router._generic_api_call_with_fallbacks(
model="test-model-alias",
original_function=mock_function,
messages=[{"role": "user", "content": "Hello"}],
max_tokens=100,
)
else:
response = asyncio.run(
router._ageneric_api_call_with_fallbacks(
model="test-model-alias",
original_function=mock_function,
messages=[{"role": "user", "content": "Hello"}],
max_tokens=100,
)
)

# Verify the mock function was called
mock_function.assert_called_once()

@ -510,3 +526,36 @@ async def test__aadapter_completion():

# Verify async_routing_strategy_pre_call_checks was called
router.async_routing_strategy_pre_call_checks.assert_called_once()


def test_initialize_router_endpoints():
"""
Test that initialize_router_endpoints correctly sets up all router endpoints
"""
# Create a router with a basic model
router = Router(
model_list=[
{
"model_name": "test-model",
"litellm_params": {
"model": "anthropic/test-model",
"api_key": "fake-api-key",
},
}
]
)

# Explicitly call initialize_router_endpoints
router.initialize_router_endpoints()

# Verify all expected endpoints are initialized
assert hasattr(router, "amoderation")
assert hasattr(router, "aanthropic_messages")
assert hasattr(router, "aresponses")
assert hasattr(router, "responses")

# Verify the endpoints are callable
assert callable(router.amoderation)
assert callable(router.aanthropic_messages)
assert callable(router.aresponses)
assert callable(router.responses)