mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00

(Bug fix) - allow using Assistants GET, DELETE on /openai pass through routes (#8818)
* test_openai_assistants_e2e_operations
* test openai assistants pass through
* fix GET request on pass through handler
* _make_non_streaming_http_request
* _is_assistants_api_request
* test_openai_assistants_e2e_operations
* test_openai_assistants_e2e_operations
* openai_proxy_route
* docs openai pass through
* docs openai pass through
* docs openai pass through
* test pass through handler
* Potential fix for code scanning alert no. 2240: Incomplete URL substring sanitization

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
This commit is contained in:
parent
142276b468
commit
11fd5094c7
8 changed files with 572 additions and 84 deletions
docs/my-website/docs/pass_through/openai_passthrough.md (new file, 95 additions)

@@ -0,0 +1,95 @@
# OpenAI Passthrough

Pass-through endpoints for `/openai`

## Overview

| Feature | Supported | Notes |
|---------|-----------|-------|
| Cost Tracking | ❌ | Not supported |
| Logging | ✅ | Works across all integrations |
| Streaming | ✅ | Fully supported |

### When to use this?

- For 90% of your use cases, you should use the [native LiteLLM OpenAI Integration](https://docs.litellm.ai/docs/providers/openai) (`/chat/completions`, `/embeddings`, `/completions`, `/images`, `/batches`, etc.)
- Use this passthrough to call less popular or newer OpenAI endpoints that LiteLLM doesn't fully support yet, such as `/assistants`, `/threads`, and `/vector_stores`

Simply replace `https://api.openai.com` with `LITELLM_PROXY_BASE_URL/openai`

## Usage Examples

### Assistants API

#### Create OpenAI Client

Make sure you do the following:
- Point `base_url` to your `LITELLM_PROXY_BASE_URL/openai`
- Use your `LITELLM_API_KEY` as the `api_key`

```python
import openai

client = openai.OpenAI(
    base_url="http://0.0.0.0:4000/openai",  # <your-proxy-url>/openai
    api_key="sk-anything",  # <your-proxy-api-key>
)
```
#### Create an Assistant

```python
# Create an assistant
assistant = client.beta.assistants.create(
    name="Math Tutor",
    instructions="You are a math tutor. Help solve equations.",
    model="gpt-4o",
)
```

#### Create a Thread

```python
# Create a thread
thread = client.beta.threads.create()
```

#### Add a Message to the Thread

```python
# Add a message
message = client.beta.threads.messages.create(
    thread_id=thread.id,
    role="user",
    content="Solve 3x + 11 = 14",
)
```

#### Run the Assistant

```python
# Create a run to get the assistant's response
run = client.beta.threads.runs.create(
    thread_id=thread.id,
    assistant_id=assistant.id,
)

# Check run status
run_status = client.beta.threads.runs.retrieve(
    thread_id=thread.id,
    run_id=run.id
)
```
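
Runs execute asynchronously, so the status retrieved right after creation may still be `queued` or `in_progress`. A minimal polling sketch (status names per the OpenAI Runs API; the one-second interval is arbitrary):

```python
import time

# Poll until the run leaves the in-flight states (e.g. reaches "completed").
while run_status.status in ("queued", "in_progress"):
    time.sleep(1)
    run_status = client.beta.threads.runs.retrieve(
        thread_id=thread.id,
        run_id=run.id,
    )

print(run_status.status)
```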
#### Retrieve Messages

```python
# List messages after the run completes
messages = client.beta.threads.messages.list(
    thread_id=thread.id
)
```

#### Delete the Assistant

```python
# Delete the assistant when done
client.beta.assistants.delete(assistant.id)
```
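
The other routes listed above work the same way through this client. A hedged sketch for `/vector_stores` (not part of the original examples; on older `openai` SDK versions the attribute lives under `client.beta.vector_stores` instead):

```python
# Sketch: create and delete a vector store via the same /openai passthrough.
vector_store = client.vector_stores.create(name="Support FAQ")  # assumed route
print(vector_store.id)
client.vector_stores.delete(vector_store.id)
```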

```diff
@@ -303,6 +303,7 @@ const sidebars = {
             "pass_through/vertex_ai",
             "pass_through/google_ai_studio",
             "pass_through/cohere",
+            "pass_through/openai_passthrough",
             "pass_through/anthropic_completion",
             "pass_through/bedrock",
             "pass_through/assembly_ai",
```

```diff
@@ -240,3 +240,18 @@ class RouteChecks:
             RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
             for allowed_route in allowed_routes
         )  # Check pattern match
+
+    @staticmethod
+    def _is_assistants_api_request(request: Request) -> bool:
+        """
+        Returns True if `thread` or `assistant` is in the request path
+
+        Args:
+            request (Request): The request object
+
+        Returns:
+            bool: True if `thread` or `assistant` is in the request path, False otherwise
+        """
+        if "thread" in request.url.path or "assistant" in request.url.path:
+            return True
+        return False
```
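
Note that `_is_assistants_api_request` is a plain substring match on the URL path, so any route whose path contains `thread` or `assistant` is treated as an Assistants API request. A quick illustration of the semantics (hypothetical paths):

```python
# Same check as above, applied to sample paths (illustration only).
for path in (
    "/v1/assistants",                      # matches: contains "assistant"
    "/openai/v1/threads/thread_123/runs",  # matches: contains "thread"
    "/v1/chat/completions",                # no match
):
    print(path, "thread" in path or "assistant" in path)
```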

```diff
@@ -17,6 +17,7 @@ from litellm.proxy._types import (
     TeamCallbackMetadata,
     UserAPIKeyAuth,
 )
+from litellm.proxy.auth.route_checks import RouteChecks
 from litellm.router import Router
 from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS
 from litellm.types.services import ServiceTypes
```

```diff
@@ -59,7 +60,7 @@ def _get_metadata_variable_name(request: Request) -> str:

     For ALL other endpoints we call this "metadata
     """
-    if "thread" in request.url.path or "assistant" in request.url.path:
+    if RouteChecks._is_assistants_api_request(request):
         return "litellm_metadata"
     if "batches" in request.url.path:
         return "litellm_metadata"
```

```diff
@@ -14,6 +14,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, Response
 import litellm
 from litellm.constants import BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES
 from litellm.proxy._types import *
+from litellm.proxy.auth.route_checks import RouteChecks
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
     create_pass_through_route,
```

```diff
@@ -405,7 +406,7 @@ async def azure_proxy_route(
             "Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure."
         )

-    return await _base_openai_pass_through_handler(
+    return await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler(
         endpoint=endpoint,
         request=request,
         fastapi_response=fastapi_response,
```

```diff
@@ -431,7 +432,7 @@ async def openai_proxy_route(


     """
-    base_target_url = "https://api.openai.com"
+    base_target_url = "https://api.openai.com/"
     # Add or update query parameters
     openai_api_key = passthrough_endpoint_router.get_credentials(
         custom_llm_provider="openai",
```

```diff
@@ -442,7 +443,7 @@ async def openai_proxy_route(
             "Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI."
         )

-    return await _base_openai_pass_through_handler(
+    return await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler(
         endpoint=endpoint,
         request=request,
         fastapi_response=fastapi_response,
```

```diff
@@ -452,44 +453,99 @@ async def openai_proxy_route(
     )


-async def _base_openai_pass_through_handler(
-    endpoint: str,
-    request: Request,
-    fastapi_response: Response,
-    user_api_key_dict: UserAPIKeyAuth,
-    base_target_url: str,
-    api_key: str,
-):
-    encoded_endpoint = httpx.URL(endpoint).path
-
-    # Ensure endpoint starts with '/' for proper URL construction
-    if not encoded_endpoint.startswith("/"):
-        encoded_endpoint = "/" + encoded_endpoint
-
-    # Construct the full target URL using httpx
-    base_url = httpx.URL(base_target_url)
-    updated_url = base_url.copy_with(path=encoded_endpoint)
-
-    ## check for streaming
-    is_streaming_request = False
-    if "stream" in str(updated_url):
-        is_streaming_request = True
-
-    ## CREATE PASS-THROUGH
-    endpoint_func = create_pass_through_route(
-        endpoint=endpoint,
-        target=str(updated_url),
-        custom_headers={
-            "authorization": "Bearer {}".format(api_key),
-            "api-key": "{}".format(api_key),
-        },
-    )  # dynamically construct pass-through endpoint based on incoming path
-    received_value = await endpoint_func(
-        request,
-        fastapi_response,
-        user_api_key_dict,
-        stream=is_streaming_request,  # type: ignore
-        query_params=dict(request.query_params),  # type: ignore
-    )
-
-    return received_value
+class BaseOpenAIPassThroughHandler:
+    @staticmethod
+    async def _base_openai_pass_through_handler(
+        endpoint: str,
+        request: Request,
+        fastapi_response: Response,
+        user_api_key_dict: UserAPIKeyAuth,
+        base_target_url: str,
+        api_key: str,
+    ):
+        encoded_endpoint = httpx.URL(endpoint).path
+
+        # Ensure endpoint starts with '/' for proper URL construction
+        if not encoded_endpoint.startswith("/"):
+            encoded_endpoint = "/" + encoded_endpoint
+
+        # Ensure base_target_url is properly formatted for OpenAI
+        base_target_url = (
+            BaseOpenAIPassThroughHandler._append_v1_to_openai_passthrough_url(
+                base_target_url
+            )
+        )
+
+        # Construct the full target URL by properly joining the base URL and endpoint path
+        base_url = httpx.URL(base_target_url)
+        updated_url = BaseOpenAIPassThroughHandler._join_url_paths(
+            base_url, encoded_endpoint
+        )
+
+        ## check for streaming
+        is_streaming_request = False
+        if "stream" in str(updated_url):
+            is_streaming_request = True
+
+        ## CREATE PASS-THROUGH
+        endpoint_func = create_pass_through_route(
+            endpoint=endpoint,
+            target=str(updated_url),
+            custom_headers=BaseOpenAIPassThroughHandler._assemble_headers(
+                api_key=api_key, request=request
+            ),
+        )  # dynamically construct pass-through endpoint based on incoming path
+        received_value = await endpoint_func(
+            request,
+            fastapi_response,
+            user_api_key_dict,
+            stream=is_streaming_request,  # type: ignore
+            query_params=dict(request.query_params),  # type: ignore
+        )
+
+        return received_value
+
+    @staticmethod
+    def _append_v1_to_openai_passthrough_url(base_url: str) -> str:
+        """
+        Appends the /v1 path to the OpenAI base URL if it's the OpenAI API URL
+        """
+        if base_url.rstrip("/") == "https://api.openai.com":
+            return "https://api.openai.com/v1"
+        return base_url
+
+    @staticmethod
+    def _append_openai_beta_header(headers: dict, request: Request) -> dict:
+        """
+        Appends the OpenAI-Beta header to the headers if the request is an OpenAI Assistants API request
+        """
+        if RouteChecks._is_assistants_api_request(request) is True:
+            headers["OpenAI-Beta"] = "assistants=v2"
+        return headers
+
+    @staticmethod
+    def _assemble_headers(api_key: str, request: Request) -> dict:
+        base_headers = {
+            "authorization": "Bearer {}".format(api_key),
+            "api-key": "{}".format(api_key),
+        }
+        return BaseOpenAIPassThroughHandler._append_openai_beta_header(
+            headers=base_headers,
+            request=request,
+        )
+
+    @staticmethod
+    def _join_url_paths(base_url: httpx.URL, path: str) -> httpx.URL:
+        """
+        Properly joins a base URL with a path, preserving any existing path in the base URL.
+        """
+        if not base_url.path or base_url.path == "/":
+            # If base URL has no path, just use the new path
+            return base_url.copy_with(path=path)
+
+        # Join paths correctly by removing trailing/leading slashes as needed
+        base_path = base_url.path.rstrip("/")
+        clean_path = path.lstrip("/")
+        full_path = f"{base_path}/{clean_path}"
+
+        return base_url.copy_with(path=full_path)
```
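
Together, `_append_v1_to_openai_passthrough_url` and `_join_url_paths` normalize the bare OpenAI base URL to `https://api.openai.com/v1` and then append the incoming path without dropping the `/v1` prefix, which the old `copy_with(path=...)` call did. A standalone sketch of the joining rule (same logic as above, outside the class):

```python
import httpx

def join_url_paths(base_url: httpx.URL, path: str) -> httpx.URL:
    # Keep any existing base path instead of overwriting it.
    if not base_url.path or base_url.path == "/":
        return base_url.copy_with(path=path)
    return base_url.copy_with(path=f"{base_url.path.rstrip('/')}/{path.lstrip('/')}")

base = httpx.URL("https://api.openai.com/v1")  # after /v1 normalization
print(join_url_paths(base, "/assistants"))
# https://api.openai.com/v1/assistants
# The old base.copy_with(path="/assistants") would have produced
# https://api.openai.com/assistants, silently dropping /v1.
```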

```diff
@@ -6,6 +6,7 @@ from datetime import datetime
 from typing import List, Optional

 import httpx
+from urllib.parse import urlparse
 from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
 from fastapi.responses import StreamingResponse

```

```diff
@@ -259,48 +260,82 @@ async def chat_completion_pass_through_endpoint(  # noqa: PLR0915
     )


-def forward_headers_from_request(
-    request: Request,
-    headers: dict,
-    forward_headers: Optional[bool] = False,
-):
-    """
-    Helper to forward headers from original request
-    """
-    if forward_headers is True:
-        request_headers = dict(request.headers)
-
-        # Header We Should NOT forward
-        request_headers.pop("content-length", None)
-        request_headers.pop("host", None)
-
-        # Combine request headers with custom headers
-        headers = {**request_headers, **headers}
-    return headers
-
-
-def get_response_headers(
-    headers: httpx.Headers, litellm_call_id: Optional[str] = None
-) -> dict:
-    excluded_headers = {"transfer-encoding", "content-encoding"}
-
-    return_headers = {
-        key: value
-        for key, value in headers.items()
-        if key.lower() not in excluded_headers
-    }
-    if litellm_call_id:
-        return_headers["x-litellm-call-id"] = litellm_call_id
-
-    return return_headers
-
-
-def get_endpoint_type(url: str) -> EndpointType:
-    if ("generateContent") in url or ("streamGenerateContent") in url:
-        return EndpointType.VERTEX_AI
-    elif ("api.anthropic.com") in url:
-        return EndpointType.ANTHROPIC
-    return EndpointType.GENERIC
+class HttpPassThroughEndpointHelpers:
+    @staticmethod
+    def forward_headers_from_request(
+        request: Request,
+        headers: dict,
+        forward_headers: Optional[bool] = False,
+    ):
+        """
+        Helper to forward headers from original request
+        """
+        if forward_headers is True:
+            request_headers = dict(request.headers)
+
+            # Header We Should NOT forward
+            request_headers.pop("content-length", None)
+            request_headers.pop("host", None)
+
+            # Combine request headers with custom headers
+            headers = {**request_headers, **headers}
+        return headers
+
+    @staticmethod
+    def get_response_headers(
+        headers: httpx.Headers, litellm_call_id: Optional[str] = None
+    ) -> dict:
+        excluded_headers = {"transfer-encoding", "content-encoding"}
+
+        return_headers = {
+            key: value
+            for key, value in headers.items()
+            if key.lower() not in excluded_headers
+        }
+        if litellm_call_id:
+            return_headers["x-litellm-call-id"] = litellm_call_id
+
+        return return_headers
+
+    @staticmethod
+    def get_endpoint_type(url: str) -> EndpointType:
+        parsed_url = urlparse(url)
+        if ("generateContent") in url or ("streamGenerateContent") in url:
+            return EndpointType.VERTEX_AI
+        elif parsed_url.hostname == "api.anthropic.com":
+            return EndpointType.ANTHROPIC
+        return EndpointType.GENERIC
+
+    @staticmethod
+    async def _make_non_streaming_http_request(
+        request: Request,
+        async_client: httpx.AsyncClient,
+        url: str,
+        headers: dict,
+        requested_query_params: Optional[dict] = None,
+        custom_body: Optional[dict] = None,
+    ) -> httpx.Response:
+        """
+        Make a non-streaming HTTP request
+
+        If request is GET, don't include a JSON body
+        """
+        if request.method == "GET":
+            response = await async_client.request(
+                method=request.method,
+                url=url,
+                headers=headers,
+                params=requested_query_params,
+            )
+        else:
+            response = await async_client.request(
+                method=request.method,
+                url=url,
+                headers=headers,
+                params=requested_query_params,
+                json=custom_body,
+            )
+        return response


 async def pass_through_request(  # noqa: PLR0915
```
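
The switch from a substring check to `urlparse(url).hostname` in `get_endpoint_type` is the fix for the code-scanning alert about incomplete URL substring sanitization: a substring match also accepts hostnames that merely contain `api.anthropic.com`. A quick illustration (hypothetical attacker-controlled URL):

```python
from urllib.parse import urlparse

url = "https://api.anthropic.com.evil.example/v1/messages"

print("api.anthropic.com" in url)                     # True  - substring check is fooled
print(urlparse(url).hostname == "api.anthropic.com")  # False - hostname check is not
```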

```diff
@@ -321,11 +356,13 @@ async def pass_through_request(  # noqa: PLR0915

     url = httpx.URL(target)
     headers = custom_headers
-    headers = forward_headers_from_request(
+    headers = HttpPassThroughEndpointHelpers.forward_headers_from_request(
         request=request, headers=headers, forward_headers=forward_headers
     )

-    endpoint_type: EndpointType = get_endpoint_type(str(url))
+    endpoint_type: EndpointType = HttpPassThroughEndpointHelpers.get_endpoint_type(
+        str(url)
+    )

     _parsed_body = None
     if custom_body:
```

```diff
@@ -442,7 +479,7 @@ async def pass_through_request(  # noqa: PLR0915
                 passthrough_success_handler_obj=pass_through_endpoint_logging,
                 url_route=str(url),
             ),
-            headers=get_response_headers(
+            headers=HttpPassThroughEndpointHelpers.get_response_headers(
                 headers=response.headers,
                 litellm_call_id=litellm_call_id,
             ),
```

```diff
@@ -457,13 +494,21 @@ async def pass_through_request(  # noqa: PLR0915
         )
         verbose_proxy_logger.debug("request body: {}".format(_parsed_body))

-        response = await async_client.request(
-            method=request.method,
-            url=url,
-            headers=headers,
-            params=requested_query_params,
-            json=_parsed_body,
-        )
+        if request.method == "GET":
+            response = await async_client.request(
+                method=request.method,
+                url=url,
+                headers=headers,
+                params=requested_query_params,
+            )
+        else:
+            response = await async_client.request(
+                method=request.method,
+                url=url,
+                headers=headers,
+                params=requested_query_params,
+                json=_parsed_body,
+            )

         verbose_proxy_logger.debug("response.headers= %s", response.headers)
```
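
This branch mirrors `_make_non_streaming_http_request` above: passing `json=` on a GET makes httpx attach a serialized body and a `Content-Length` header, which is what broke Assistants GET requests through the passthrough. A small sketch of the difference (hypothetical URL):

```python
import httpx

plain_get = httpx.Request("GET", "https://example.invalid/v1/assistants")
get_with_body = httpx.Request("GET", "https://example.invalid/v1/assistants", json={})

print(plain_get.headers.get("content-length"))      # None - no body attached
print(get_with_body.headers.get("content-length"))  # "2"  - sends b"{}"
```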

```diff
@@ -485,7 +530,7 @@ async def pass_through_request(  # noqa: PLR0915
                 passthrough_success_handler_obj=pass_through_endpoint_logging,
                 url_route=str(url),
             ),
-            headers=get_response_headers(
+            headers=HttpPassThroughEndpointHelpers.get_response_headers(
                 headers=response.headers,
                 litellm_call_id=litellm_call_id,
             ),
```

```diff
@@ -525,7 +570,7 @@ async def pass_through_request(  # noqa: PLR0915
     return Response(
         content=content,
         status_code=response.status_code,
-        headers=get_response_headers(
+        headers=HttpPassThroughEndpointHelpers.get_response_headers(
             headers=response.headers,
             litellm_call_id=litellm_call_id,
         ),
```

@@ -0,0 +1,194 @@
```python
import json
import os
import sys
from unittest.mock import MagicMock, patch

import httpx
import pytest
from fastapi import Request, Response
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../../..")
)  # Adds the parent directory to the system path

from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
    BaseOpenAIPassThroughHandler,
    RouteChecks,
    create_pass_through_route,
)


class TestBaseOpenAIPassThroughHandler:
    def test_append_v1_to_openai_passthrough_url(self):
        print("\nTesting _append_v1_to_openai_passthrough_url method...")

        # Test with OpenAI API URL
        result1 = BaseOpenAIPassThroughHandler._append_v1_to_openai_passthrough_url(
            "https://api.openai.com"
        )
        print(f"OpenAI URL: 'https://api.openai.com' → '{result1}'")
        assert result1 == "https://api.openai.com/v1"

        # Test with OpenAI API URL with trailing slash
        result2 = BaseOpenAIPassThroughHandler._append_v1_to_openai_passthrough_url(
            "https://api.openai.com/"
        )
        print(
            f"OpenAI URL with trailing slash: 'https://api.openai.com/' → '{result2}'"
        )
        assert result2 == "https://api.openai.com/v1"

        # Test with non-OpenAI URL
        result3 = BaseOpenAIPassThroughHandler._append_v1_to_openai_passthrough_url(
            "https://api.anthropic.com"
        )
        print(f"Non-OpenAI URL: 'https://api.anthropic.com' → '{result3}'")
        assert result3 == "https://api.anthropic.com"

    def test_join_url_paths(self):
        print("\nTesting _join_url_paths method...")

        # Test joining base URL with no path and a path
        base_url = httpx.URL("https://api.example.com")
        path = "/v1/chat/completions"
        result = BaseOpenAIPassThroughHandler._join_url_paths(base_url, path)
        print(f"Base URL with no path: '{base_url}' + '{path}' → '{result}'")
        assert str(result) == "https://api.example.com/v1/chat/completions"

        # Test joining base URL with path and another path
        base_url = httpx.URL("https://api.example.com/v1")
        path = "/chat/completions"
        result = BaseOpenAIPassThroughHandler._join_url_paths(base_url, path)
        print(f"Base URL with path: '{base_url}' + '{path}' → '{result}'")
        assert str(result) == "https://api.example.com/v1/chat/completions"

        # Test with path not starting with slash
        base_url = httpx.URL("https://api.example.com/v1")
        path = "chat/completions"
        result = BaseOpenAIPassThroughHandler._join_url_paths(base_url, path)
        print(f"Path without leading slash: '{base_url}' + '{path}' → '{result}'")
        assert str(result) == "https://api.example.com/v1/chat/completions"

        # Test with base URL having trailing slash
        base_url = httpx.URL("https://api.example.com/v1/")
        path = "/chat/completions"
        result = BaseOpenAIPassThroughHandler._join_url_paths(base_url, path)
        print(f"Base URL with trailing slash: '{base_url}' + '{path}' → '{result}'")
        assert str(result) == "https://api.example.com/v1/chat/completions"

    def test_append_openai_beta_header(self):
        print("\nTesting _append_openai_beta_header method...")

        # Create mock requests with different paths
        assistants_request = MagicMock(spec=Request)
        assistants_request.url = MagicMock()
        assistants_request.url.path = "/v1/threads/thread_123456/messages"

        non_assistants_request = MagicMock(spec=Request)
        non_assistants_request.url = MagicMock()
        non_assistants_request.url.path = "/v1/chat/completions"

        headers = {"authorization": "Bearer test_key"}

        # Test with assistants API request
        result = BaseOpenAIPassThroughHandler._append_openai_beta_header(
            headers, assistants_request
        )
        print(f"Assistants API request: Added header: {result}")
        assert result["OpenAI-Beta"] == "assistants=v2"

        # Test with non-assistants API request
        headers = {"authorization": "Bearer test_key"}
        result = BaseOpenAIPassThroughHandler._append_openai_beta_header(
            headers, non_assistants_request
        )
        print(f"Non-assistants API request: Headers: {result}")
        assert "OpenAI-Beta" not in result

        # Test with assistant in the path
        assistant_request = MagicMock(spec=Request)
        assistant_request.url = MagicMock()
        assistant_request.url.path = "/v1/assistants/asst_123456"

        headers = {"authorization": "Bearer test_key"}
        result = BaseOpenAIPassThroughHandler._append_openai_beta_header(
            headers, assistant_request
        )
        print(f"Assistant API request: Added header: {result}")
        assert result["OpenAI-Beta"] == "assistants=v2"

    def test_assemble_headers(self):
        print("\nTesting _assemble_headers method...")

        # Mock request
        mock_request = MagicMock(spec=Request)
        api_key = "test_api_key"

        # Patch the _append_openai_beta_header method to avoid testing it again
        with patch.object(
            BaseOpenAIPassThroughHandler,
            "_append_openai_beta_header",
            return_value={
                "authorization": "Bearer test_api_key",
                "api-key": "test_api_key",
                "test-header": "value",
            },
        ):
            result = BaseOpenAIPassThroughHandler._assemble_headers(
                api_key, mock_request
            )
            print(f"Assembled headers: {result}")
            assert result["authorization"] == "Bearer test_api_key"
            assert result["api-key"] == "test_api_key"
            assert result["test-header"] == "value"

    @patch(
        "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
    )
    async def test_base_openai_pass_through_handler(self, mock_create_pass_through):
        print("\nTesting _base_openai_pass_through_handler method...")

        # Mock dependencies
        mock_request = MagicMock(spec=Request)
        mock_request.query_params = {"model": "gpt-4"}
        mock_response = MagicMock(spec=Response)
        mock_user_api_key_dict = MagicMock()

        # Mock the endpoint function returned by create_pass_through_route
        mock_endpoint_func = MagicMock()
        mock_endpoint_func.return_value = {"result": "success"}
        mock_create_pass_through.return_value = mock_endpoint_func

        print("Testing standard endpoint pass-through...")
        # Test with standard endpoint
        result = await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler(
            endpoint="/chat/completions",
            request=mock_request,
            fastapi_response=mock_response,
            user_api_key_dict=mock_user_api_key_dict,
            base_target_url="https://api.openai.com",
            api_key="test_api_key",
        )

        # Verify the result
        print(f"Result from handler: {result}")
        assert result == {"result": "success"}

        # Verify create_pass_through_route was called with correct parameters
        call_args = mock_create_pass_through.call_args[1]
        print(
            f"create_pass_through_route called with endpoint: {call_args['endpoint']}"
        )
        print(f"create_pass_through_route called with target: {call_args['target']}")
        assert call_args["endpoint"] == "/chat/completions"
        assert call_args["target"] == "https://api.openai.com/v1/chat/completions"

        # Verify endpoint_func was called with correct parameters
        print("Verifying endpoint_func call parameters...")
        call_kwargs = mock_endpoint_func.call_args[1]
        print(f"stream parameter: {call_kwargs['stream']}")
        print(f"query_params: {call_kwargs['query_params']}")
        assert call_kwargs["stream"] is False
        assert call_kwargs["query_params"] == {"model": "gpt-4"}
```

@@ -0,0 +1,81 @@
```python
import pytest
import openai
import aiohttp
import asyncio
from typing_extensions import override
from openai import AssistantEventHandler

client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234")


def test_openai_assistants_e2e_operations():
    assistant = client.beta.assistants.create(
        name="Math Tutor",
        instructions="You are a personal math tutor. Write and run code to answer math questions.",
        tools=[{"type": "code_interpreter"}],
        model="gpt-4o",
    )
    print("assistant created", assistant)

    get_assistant = client.beta.assistants.retrieve(assistant.id)
    print(get_assistant)

    delete_assistant = client.beta.assistants.delete(assistant.id)
    print(delete_assistant)


class EventHandler(AssistantEventHandler):
    @override
    def on_text_created(self, text) -> None:
        print(f"\nassistant > ", end="", flush=True)

    @override
    def on_text_delta(self, delta, snapshot):
        print(delta.value, end="", flush=True)

    def on_tool_call_created(self, tool_call):
        print(f"\nassistant > {tool_call.type}\n", flush=True)

    def on_tool_call_delta(self, delta, snapshot):
        if delta.type == "code_interpreter":
            if delta.code_interpreter.input:
                print(delta.code_interpreter.input, end="", flush=True)
            if delta.code_interpreter.outputs:
                print(f"\n\noutput >", flush=True)
                for output in delta.code_interpreter.outputs:
                    if output.type == "logs":
                        print(f"\n{output.logs}", flush=True)


def test_openai_assistants_e2e_operations_stream():
    assistant = client.beta.assistants.create(
        name="Math Tutor",
        instructions="You are a personal math tutor. Write and run code to answer math questions.",
        tools=[{"type": "code_interpreter"}],
        model="gpt-4o",
    )
    print("assistant created", assistant)

    thread = client.beta.threads.create()
    print("thread created", thread)

    message = client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content="I need to solve the equation `3x + 11 = 14`. Can you help me?",
    )
    print("message created", message)

    # Then, we use the `stream` SDK helper
    # with the `EventHandler` class to create the Run
    # and stream the response.

    with client.beta.threads.runs.stream(
        thread_id=thread.id,
        assistant_id=assistant.id,
        instructions="Please address the user as Jane Doe. The user has a premium account.",
        event_handler=EventHandler(),
    ) as stream:
        stream.until_done()
```