Mirror of https://github.com/BerriAI/litellm.git
Synced 2025-04-26 19:24:27 +00:00
feat(google_ai_studio_endpoints.py): support pass-through endpoint for all google ai studio requests

New Feature

This commit is contained in:
parent 668ea6cbc7
commit 29bedae79f

6 changed files with 186 additions and 20 deletions
@@ -412,7 +412,7 @@ def get_replicate_completion_pricing(completion_response=None, total_time=0.0):
 
 def _select_model_name_for_cost_calc(
     model: Optional[str],
-    completion_response: Union[BaseModel, dict],
+    completion_response: Union[BaseModel, dict, str],
     base_model: Optional[str] = None,
     custom_pricing: Optional[bool] = None,
 ) -> Optional[str]:
@@ -428,7 +428,12 @@ def _select_model_name_for_cost_calc(
     if base_model is not None:
         return base_model
 
-    return_model = model or completion_response.get("model", "")  # type: ignore
+    return_model = model
+    if isinstance(completion_response, str):
+        return return_model
+
+    elif return_model is None:
+        return_model = completion_response.get("model", "")  # type: ignore
     if hasattr(completion_response, "_hidden_params"):
         if (
             completion_response._hidden_params.get("model", None) is not None
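Why widen completion_response to Union[BaseModel, dict, str]: pass-through endpoints hand back raw string bodies, and calling .get() on a string would raise. A standalone sketch of the new branching (dict/str only, omitting the BaseModel and _hidden_params handling shown in the diff):

from typing import Optional, Union


def select_model_name_sketch(
    model: Optional[str],
    completion_response: Union[dict, str],
    base_model: Optional[str] = None,
) -> Optional[str]:
    # Mirrors the updated _select_model_name_for_cost_calc control flow.
    if base_model is not None:
        return base_model
    return_model = model
    if isinstance(completion_response, str):
        # Raw pass-through body: nothing to introspect, fall back to `model`.
        return return_model
    elif return_model is None:
        return_model = completion_response.get("model", "")
    return return_model


assert select_model_name_sketch("gemini-1.5-flash", "raw text body") == "gemini-1.5-flash"
assert select_model_name_sketch(None, {"model": "gpt-4o"}) == "gpt-4o"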
@@ -274,6 +274,7 @@ class Logging:
             headers = {}
             data = additional_args.get("complete_input_dict", {})
             api_base = str(additional_args.get("api_base", ""))
+            query_params = additional_args.get("query_params", {})
             if "key=" in api_base:
                 # Find the position of "key=" in the string
                 key_index = api_base.find("key=") + 4
@@ -2362,7 +2363,7 @@ def get_standard_logging_object_payload(
 
         return payload
     except Exception as e:
-        verbose_logger.error(
+        verbose_logger.exception(
             "Error creating standard logging object - {}".format(str(e))
         )
         return None
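Two small changes here: the logger now also reads query_params from additional_args, and the final except block logs via verbose_logger.exception, which attaches the active traceback instead of just the formatted message. The "key=" scan below the new line locates the query-string API key inside api_base so it can be masked before logging; the masking body falls outside this hunk, so the following is only a sketch of the idea, not the repo's exact code:

def mask_key_in_api_base(api_base: str) -> str:
    # Hypothetical helper: blank out whatever follows "key=" so logs
    # never carry the raw upstream API key.
    if "key=" in api_base:
        key_index = api_base.find("key=") + 4  # index just past "key="
        return api_base[:key_index] + "********"
    return api_base


assert (
    mask_key_in_api_base("https://example.com/v1?key=secret")
    == "https://example.com/v1?key=********"
)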
@@ -273,6 +273,7 @@ async def pass_through_request(
     custom_headers: dict,
     user_api_key_dict: UserAPIKeyAuth,
     forward_headers: Optional[bool] = False,
+    query_params: Optional[dict] = None,
 ):
     try:
         import time
@@ -308,23 +309,9 @@ async def pass_through_request(
         )
 
         async_client = httpx.AsyncClient()
 
-        response = await async_client.request(
-            method=request.method,
-            url=url,
-            headers=headers,
-            params=request.query_params,
-            json=_parsed_body,
-        )
-
-        if response.status_code >= 300:
-            raise HTTPException(status_code=response.status_code, detail=response.text)
-
-        content = await response.aread()
-
-        ## LOG SUCCESS
-        start_time = time.time()
-        end_time = time.time()
+        # create logging object
+        start_time = time.time()
         logging_obj = Logging(
             model="unknown",
             messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
@@ -334,6 +321,7 @@ async def pass_through_request(
             litellm_call_id=str(uuid.uuid4()),
             function_id="1245",
         )
 
+        # done for supporting 'parallel_request_limiter.py' with pass-through endpoints
         kwargs = {
             "litellm_params": {
@@ -354,6 +342,44 @@ async def pass_through_request(
             call_type="pass_through_endpoint",
         )
 
+        # combine url with query params for logging
+
+        requested_query_params = query_params or request.query_params.__dict__
+        requested_query_params_str = "&".join(
+            f"{k}={v}" for k, v in requested_query_params.items()
+        )
+
+        if "?" in str(url):
+            logging_url = str(url) + "&" + requested_query_params_str
+        else:
+            logging_url = str(url) + "?" + requested_query_params_str
+
+        logging_obj.pre_call(
+            input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
+            api_key="",
+            additional_args={
+                "complete_input_dict": _parsed_body,
+                "api_base": logging_url,
+                "headers": headers,
+            },
+        )
+
+        response = await async_client.request(
+            method=request.method,
+            url=url,
+            headers=headers,
+            params=requested_query_params,
+            json=_parsed_body,
+        )
+
+        if response.status_code >= 300:
+            raise HTTPException(status_code=response.status_code, detail=response.text)
+
+        content = await response.aread()
+
+        ## LOG SUCCESS
+        end_time = time.time()
+
         await logging_obj.async_success_handler(
             result="",
             start_time=start_time,
@@ -431,17 +457,19 @@ def create_pass_through_route(
     except Exception:
         verbose_proxy_logger.debug("Defaulting to target being a url.")
 
-    async def endpoint_func(
+    async def endpoint_func(  # type: ignore
        request: Request,
        fastapi_response: Response,
        user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+       query_params: Optional[dict] = None,
     ):
-        return await pass_through_request(
+        return await pass_through_request(  # type: ignore
            request=request,
            target=target,
            custom_headers=custom_headers or {},
            user_api_key_dict=user_api_key_dict,
            forward_headers=_forward_headers,
+           query_params=query_params,
        )
 
     return endpoint_func
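Two behavioral fixes are buried in this file. First, start_time is now captured before the upstream request fires; previously both timestamps were taken back to back after the response had already been read, so the logged duration said nothing about the call. Second, the forwarded query params (the caller-supplied dict, falling back to the incoming request's params) now drive both the upstream request and the URL handed to the logger. The URL-combining step in isolation:

def build_logging_url(url: str, requested_query_params: dict) -> str:
    # Append forwarded query params, joining with "&" when the URL
    # already carries a query string and "?" otherwise.
    qs = "&".join(f"{k}={v}" for k, v in requested_query_params.items())
    return f"{url}&{qs}" if "?" in url else f"{url}?{qs}"


url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:countTokens"
print(build_logging_url(url, {"key": "****"}))
# -> ...gemini-1.5-flash:countTokens?key=****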
@@ -227,6 +227,9 @@ from litellm.proxy.utils import (
     send_email,
     update_spend,
 )
+from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import (
+    router as gemini_router,
+)
 from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import router as vertex_router
 from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import set_default_vertex_config
 from litellm.router import (
@@ -9704,6 +9707,7 @@ def cleanup_router_config_variables():
 app.include_router(router)
 app.include_router(fine_tuning_router)
 app.include_router(vertex_router)
+app.include_router(gemini_router)
 app.include_router(pass_through_router)
 app.include_router(health_router)
 app.include_router(key_management_router)
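app.include_router(gemini_router) is what actually exposes the new /gemini/{endpoint:path} route on the proxy app. The pattern in miniature (a self-contained FastAPI sketch, not litellm's real app setup):

from fastapi import APIRouter, FastAPI

gemini_router = APIRouter()


@gemini_router.api_route(
    "/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
)
async def gemini_proxy_route(endpoint: str):
    # Placeholder body; the real handler forwards to Google AI Studio.
    return {"endpoint": endpoint}


app = FastAPI()
app.include_router(gemini_router)  # the router's routes are now served by the app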
@@ -3,3 +3,82 @@ What is this?
 
 Google AI Studio Pass-Through Endpoints
 """
+
+"""
+1. Create pass-through endpoints for any LITELLM_BASE_URL/gemini/<endpoint> map to https://generativelanguage.googleapis.com/<endpoint>
+"""
+
+import ast
+import asyncio
+import traceback
+from datetime import datetime, timedelta, timezone
+from typing import List, Optional
+from urllib.parse import urlencode
+
+import fastapi
+import httpx
+from fastapi import (
+    APIRouter,
+    Depends,
+    File,
+    Form,
+    Header,
+    HTTPException,
+    Request,
+    Response,
+    UploadFile,
+    status,
+)
+from starlette.datastructures import QueryParams
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.batches.main import FileObject
+from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
+from litellm.proxy._types import *
+from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
+    create_pass_through_route,
+)
+
+router = APIRouter()
+default_vertex_config = None
+
+
+@router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def gemini_proxy_route(
+    endpoint: str,
+    request: Request,
+    fastapi_response: Response,
+):
+    ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
+    api_key = request.query_params.get("key")
+
+    user_api_key_dict = await user_api_key_auth(
+        request=request, api_key="Bearer {}".format(api_key)
+    )
+
+    base_target_url = "https://generativelanguage.googleapis.com"
+    encoded_endpoint = httpx.URL(endpoint).path
+
+    # Ensure endpoint starts with '/' for proper URL construction
+    if not encoded_endpoint.startswith("/"):
+        encoded_endpoint = "/" + encoded_endpoint
+
+    # Construct the full target URL using httpx
+    base_url = httpx.URL(base_target_url)
+    updated_url = base_url.copy_with(path=encoded_endpoint)
+
+    # Add or update query parameters
+    gemini_api_key = litellm.utils.get_secret(secret_name="GEMINI_API_KEY")
+    # Merge query parameters, giving precedence to those in updated_url
+    merged_params = dict(request.query_params)
+    merged_params.update({"key": gemini_api_key})
+
+    endpoint_func = create_pass_through_route(
+        endpoint=endpoint,
+        target=str(updated_url),
+    )  # dynamically construct pass-through endpoint based on incoming path
+    return await endpoint_func(
+        request, fastapi_response, user_api_key_dict, query_params=merged_params
+    )
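End to end, the handler reads the caller's ?key= query param as a litellm key, authenticates it via user_api_key_auth, rebuilds the path under https://generativelanguage.googleapis.com, and swaps in GEMINI_API_KEY before forwarding. Calling it through a running proxy would look roughly like this (the proxy address and key below are placeholders):

import httpx

# Placeholder proxy address and litellm key; the path mirrors
# Google AI Studio's own API surface.
resp = httpx.post(
    "http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens",
    params={"key": "sk-1234"},  # litellm key; replaced upstream with GEMINI_API_KEY
    json={
        "contents": [
            {"parts": [{"text": "The quick brown fox jumps over the lazy dog."}]}
        ]
    },
)
print(resp.status_code, resp.json())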
@@ -1166,3 +1166,52 @@ async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
     assert new_data["success_callback"] == ["langfuse"]
     assert "langfuse_public_key" in new_data
     assert "langfuse_secret_key" in new_data
+
+
+@pytest.mark.asyncio
+async def test_gemini_pass_through_endpoint():
+    from starlette.datastructures import URL
+
+    from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import (
+        Request,
+        Response,
+        gemini_proxy_route,
+    )
+
+    body = b"""
+    {
+        "contents": [{
+            "parts":[{
+                "text": "The quick brown fox jumps over the lazy dog."
+            }]
+        }]
+    }
+    """
+
+    # Construct the scope dictionary
+    scope = {
+        "type": "http",
+        "method": "POST",
+        "path": "/gemini/v1beta/models/gemini-1.5-flash:countTokens",
+        "query_string": b"key=sk-1234",
+        "headers": [
+            (b"content-type", b"application/json"),
+        ],
+    }
+
+    # Create a new Request object
+    async def async_receive():
+        return {"type": "http.request", "body": body, "more_body": False}
+
+    request = Request(
+        scope=scope,
+        receive=async_receive,
+    )
+
+    resp = await gemini_proxy_route(
+        endpoint="v1beta/models/gemini-1.5-flash:countTokens?key=sk-1234",
+        request=request,
+        fastapi_response=Response(),
+    )
+
+    print(resp.body)
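The test builds a Starlette Request from a raw ASGI scope and receive callable rather than going through a TestClient, which lets it invoke gemini_proxy_route directly without booting the whole proxy app; the key=sk-1234 query string doubles as the litellm key the route authenticates against.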