Merge pull request #5260 from BerriAI/google_ai_studio_pass_through

Pass-through endpoints for Gemini - Google AI Studio
This commit is contained in:
Krish Dholakia 2024-08-17 13:51:51 -07:00 committed by GitHub
commit ff6ff133ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 479 additions and 31 deletions

View file

@@ -0,0 +1,223 @@
# Google AI Studio (Pass-Through)
Pass-through endpoints for Google AI Studio - call provider-specific endpoints in their native format (no translation).
Just replace `https://generativelanguage.googleapis.com` with `LITELLM_PROXY_BASE_URL/gemini` 🚀
#### **Example Usage**
```bash
curl 'http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=sk-anything' \
-H 'Content-Type: application/json' \
-d '{
    "contents": [{
        "parts":[{
            "text": "The quick brown fox jumps over the lazy dog."
        }]
    }]
}'
```
Supports **ALL** Google AI Studio Endpoints (including streaming).
[**See All Google AI Studio Endpoints**](https://ai.google.dev/api)
## Quick Start
Let's call the Gemini [`/countTokens` endpoint](https://ai.google.dev/api/tokens#method:-models.counttokens)
1. Add Gemini API Key to your environment
```bash
export GEMINI_API_KEY=""
```
2. Start LiteLLM Proxy
```bash
litellm
# RUNNING on http://0.0.0.0:4000
```
3. Test it!
Let's call the Google AI Studio token counting endpoint
```bash
curl 'http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=anything' \
-H 'Content-Type: application/json' \
-d '{
    "contents": [{
        "parts":[{
            "text": "The quick brown fox jumps over the lazy dog."
        }]
    }]
}'
```
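If you'd rather test from Python, here is a minimal sketch using `requests` (it assumes the proxy from step 2 is running on `http://0.0.0.0:4000`; the placeholder `key=anything` works because the proxy injects your `GEMINI_API_KEY` server-side):
```python
import requests

# Minimal sketch: call the countTokens endpoint through the local LiteLLM proxy.
# "anything" is a placeholder key - the proxy swaps in GEMINI_API_KEY before
# forwarding the request to Google AI Studio.
resp = requests.post(
    "http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens",
    params={"key": "anything"},
    json={
        "contents": [
            {"parts": [{"text": "The quick brown fox jumps over the lazy dog."}]}
        ]
    },
)
print(resp.json())
```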
## Examples
Anything after `http://0.0.0.0:4000/gemini` is treated as a provider-specific route, and handled accordingly.
Key Changes:
| **Original Endpoint** | **Replace With** |
|------------------------------------------------------|-----------------------------------|
| `https://generativelanguage.googleapis.com` | `http://0.0.0.0:4000/gemini` (LITELLM_PROXY_BASE_URL="http://0.0.0.0:4000") |
| `key=$GOOGLE_API_KEY` | `key=anything` (use `key=LITELLM_VIRTUAL_KEY` if Virtual Keys are set up on the proxy) |
### **Example 1: Counting tokens**
#### LiteLLM Proxy Call
```bash
curl "http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=anything" \
-H 'Content-Type: application/json' \
-X POST \
-d '{
    "contents": [{
        "parts":[{
            "text": "The quick brown fox jumps over the lazy dog."
        }]
    }]
}'
```
#### Direct Google AI Studio Call
```bash
curl "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:countTokens?key=$GOOGLE_API_KEY" \
-H 'Content-Type: application/json' \
-X POST \
-d '{
    "contents": [{
        "parts":[{
            "text": "The quick brown fox jumps over the lazy dog."
        }]
    }]
}'
```
### **Example 2: Generate content**
#### LiteLLM Proxy Call
```bash
curl "http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:generateContent?key=anything" \
-H 'Content-Type: application/json' \
-X POST \
-d '{
"contents": [{
"parts":[{"text": "Write a story about a magic backpack."}]
}]
}' 2> /dev/null
```
#### Direct Google AI Studio Call
```bash
curl "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=$GOOGLE_API_KEY" \
-H 'Content-Type: application/json' \
-X POST \
-d '{
"contents": [{
"parts":[{"text": "Write a story about a magic backpack."}]
}]
}' 2> /dev/null
```
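Streaming also passes through. Below is a rough Python sketch against the local proxy, assuming the standard `:streamGenerateContent` endpoint with `alt=sse` (server-sent events); the proxy forwards the chunks as they arrive:
```python
import requests

# Rough sketch: stream a generateContent response through the proxy.
# ":streamGenerateContent" + "alt=sse" asks Google AI Studio for server-sent
# events, which the pass-through endpoint forwards chunk by chunk.
with requests.post(
    "http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:streamGenerateContent",
    params={"key": "anything", "alt": "sse"},
    json={"contents": [{"parts": [{"text": "Write a story about a magic backpack."}]}]},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode())
```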
### **Example 3: Caching**
#### LiteLLM Proxy Call
```bash
curl -X POST "http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash-001:generateContent?key=anything" \
-H 'Content-Type: application/json' \
-d '{
    "contents": [
        {
            "parts":[{
                "text": "Please summarize this transcript"
            }],
            "role": "user"
        }
    ],
    "cachedContent": "'$CACHE_NAME'"
}'
```
#### Direct Google AI Studio Call
```bash
curl -X POST "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-001:generateContent?key=$GOOGLE_API_KEY" \
-H 'Content-Type: application/json' \
-d '{
"contents": [
{
"parts":[{
"text": "Please summarize this transcript"
}],
"role": "user"
},
],
"cachedContent": "'$CACHE_NAME'"
}'
```
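`$CACHE_NAME` above is assumed to be an existing `cachedContents` resource. A hedged sketch of creating one through the proxy, following the standard Google AI Studio caching API (note that Google enforces a minimum token count for cached content, so the transcript needs to be large):
```python
import requests

# Hypothetical sketch: create the cached content that $CACHE_NAME refers to.
# Assumes the standard Google AI Studio cachedContents API; the transcript must
# meet Google's minimum cached-token requirement to be accepted.
resp = requests.post(
    "http://0.0.0.0:4000/gemini/v1beta/cachedContents",
    params={"key": "anything"},
    json={
        "model": "models/gemini-1.5-flash-001",
        "contents": [
            {"role": "user", "parts": [{"text": "<a very long transcript...>"}]}
        ],
        "ttl": "300s",
    },
)
cache_name = resp.json()["name"]  # e.g. "cachedContents/..."
print(cache_name)
```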
## Advanced - Use with Virtual Keys
Pre-requisites
- [Setup proxy with DB](../proxy/virtual_keys.md#setup)
Use this to avoid giving developers the raw Google AI Studio key while still letting them call Google AI Studio endpoints.
### Usage
1. Setup environment
```bash
export DATABASE_URL=""
export LITELLM_MASTER_KEY=""
export GEMINI_API_KEY=""
```
```bash
litellm
# RUNNING on http://0.0.0.0:4000
```
2. Generate virtual key
```bash
curl -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{}'
```
Expected Response
```bash
{
...
"key": "sk-1234ewknldferwedojwojw"
}
```
3. Test it!
```bash
curl 'http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=sk-1234ewknldferwedojwojw' \
-H 'Content-Type: application/json' \
-d '{
    "contents": [{
        "parts":[{
            "text": "The quick brown fox jumps over the lazy dog."
        }]
    }]
}'
```
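For reference, the same test from Python - the virtual key goes in the `key` query parameter exactly where the raw Google AI Studio key would (a minimal sketch reusing the example key generated in step 2):
```python
import requests

# Minimal sketch: authenticate with the LiteLLM virtual key; the proxy validates
# it and substitutes the real GEMINI_API_KEY before forwarding to Google.
resp = requests.post(
    "http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens",
    params={"key": "sk-1234ewknldferwedojwojw"},
    json={
        "contents": [
            {"parts": [{"text": "The quick brown fox jumps over the lazy dog."}]}
        ]
    },
)
print(resp.json())
```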

View file

@@ -192,7 +192,8 @@ const sidebars = {
         "batches",
         "fine_tuning",
         "anthropic_completion",
-        "vertex_ai"
+        "pass_through/vertex_ai",
+        "pass_through/google_ai_studio"
       ],
     },
     "scheduler",

View file

@@ -412,7 +412,7 @@ def get_replicate_completion_pricing(completion_response=None, total_time=0.0):
 def _select_model_name_for_cost_calc(
     model: Optional[str],
-    completion_response: Union[BaseModel, dict],
+    completion_response: Union[BaseModel, dict, str],
     base_model: Optional[str] = None,
     custom_pricing: Optional[bool] = None,
 ) -> Optional[str]:
@@ -428,7 +428,12 @@ def _select_model_name_for_cost_calc(
     if base_model is not None:
         return base_model
-    return_model = model or completion_response.get("model", "")  # type: ignore
+    return_model = model
+    if isinstance(completion_response, str):
+        return return_model
+    elif return_model is None:
+        return_model = completion_response.get("model", "")  # type: ignore
     if hasattr(completion_response, "_hidden_params"):
         if (
             completion_response._hidden_params.get("model", None) is not None

View file

@@ -274,6 +274,7 @@ class Logging:
                 headers = {}
             data = additional_args.get("complete_input_dict", {})
             api_base = str(additional_args.get("api_base", ""))
+            query_params = additional_args.get("query_params", {})
             if "key=" in api_base:
                 # Find the position of "key=" in the string
                 key_index = api_base.find("key=") + 4
@@ -2362,7 +2363,7 @@ def get_standard_logging_object_payload(
         return payload
     except Exception as e:
-        verbose_logger.error(
+        verbose_logger.exception(
             "Error creating standard logging object - {}".format(str(e))
         )
         return None

View file

@@ -3,7 +3,7 @@ import asyncio
 import json
 import traceback
 from base64 import b64encode
-from typing import List, Optional
+from typing import AsyncIterable, List, Optional
 import httpx
 from fastapi import (
@@ -267,12 +267,25 @@ def forward_headers_from_request(
     return headers
+def get_response_headers(headers: httpx.Headers) -> dict:
+    excluded_headers = {"transfer-encoding", "content-encoding"}
+    return_headers = {
+        key: value
+        for key, value in headers.items()
+        if key.lower() not in excluded_headers
+    }
+    return return_headers
 async def pass_through_request(
     request: Request,
     target: str,
     custom_headers: dict,
     user_api_key_dict: UserAPIKeyAuth,
     forward_headers: Optional[bool] = False,
+    query_params: Optional[dict] = None,
+    stream: Optional[bool] = None,
 ):
     try:
         import time
@@ -291,7 +304,7 @@ async def pass_through_request(
         body_str = request_body.decode()
         try:
             _parsed_body = ast.literal_eval(body_str)
-        except:
+        except Exception:
             _parsed_body = json.loads(body_str)
         verbose_proxy_logger.debug(
@@ -308,23 +321,9 @@ async def pass_through_request(
         )
         async_client = httpx.AsyncClient()
-        response = await async_client.request(
-            method=request.method,
-            url=url,
-            headers=headers,
-            params=request.query_params,
-            json=_parsed_body,
-        )
-        if response.status_code >= 300:
-            raise HTTPException(status_code=response.status_code, detail=response.text)
-        content = await response.aread()
-        ## LOG SUCCESS
-        start_time = time.time()
-        end_time = time.time()
         # create logging object
+        start_time = time.time()
         logging_obj = Logging(
             model="unknown",
             messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
@@ -334,6 +333,7 @@ async def pass_through_request(
             litellm_call_id=str(uuid.uuid4()),
             function_id="1245",
         )
         # done for supporting 'parallel_request_limiter.py' with pass-through endpoints
         kwargs = {
             "litellm_params": {
@@ -354,6 +354,81 @@ async def pass_through_request(
             call_type="pass_through_endpoint",
         )
+        # combine url with query params for logging
+        requested_query_params = query_params or request.query_params.__dict__
+        requested_query_params_str = "&".join(
+            f"{k}={v}" for k, v in requested_query_params.items()
+        )
+        if "?" in str(url):
+            logging_url = str(url) + "&" + requested_query_params_str
+        else:
+            logging_url = str(url) + "?" + requested_query_params_str
+        logging_obj.pre_call(
+            input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
+            api_key="",
+            additional_args={
+                "complete_input_dict": _parsed_body,
+                "api_base": logging_url,
+                "headers": headers,
+            },
+        )
+        if stream:
+            req = async_client.build_request(
+                "POST",
+                url,
+                json=_parsed_body,
+                params=requested_query_params,
+                headers=headers,
+            )
+            response = await async_client.send(req, stream=stream)
+            # Create an async generator to yield the response content
+            async def stream_response() -> AsyncIterable[bytes]:
+                async for chunk in response.aiter_bytes():
+                    yield chunk
+            return StreamingResponse(
+                stream_response(),
+                headers=get_response_headers(response.headers),
+                status_code=response.status_code,
+            )
+        response = await async_client.request(
+            method=request.method,
+            url=url,
+            headers=headers,
+            params=requested_query_params,
+            json=_parsed_body,
+        )
+        if (
+            response.headers.get("content-type") is not None
+            and response.headers["content-type"] == "text/event-stream"
+        ):
+            # streaming response
+            # Create an async generator to yield the response content
+            async def stream_response() -> AsyncIterable[bytes]:
+                async for chunk in response.aiter_bytes():
+                    yield chunk
+            return StreamingResponse(
+                stream_response(),
+                headers=get_response_headers(response.headers),
+                status_code=response.status_code,
+            )
+        if response.status_code >= 300:
+            raise HTTPException(status_code=response.status_code, detail=response.text)
+        content = await response.aread()
+        ## LOG SUCCESS
+        end_time = time.time()
         await logging_obj.async_success_handler(
             result="",
             start_time=start_time,
@@ -361,17 +436,10 @@ async def pass_through_request(
             cache_hit=False,
         )
-        excluded_headers = {"transfer-encoding", "content-encoding"}
-        headers = {
-            key: value
-            for key, value in response.headers.items()
-            if key.lower() not in excluded_headers
-        }
         return Response(
             content=content,
             status_code=response.status_code,
-            headers=headers,
+            headers=get_response_headers(response.headers),
         )
     except Exception as e:
         verbose_proxy_logger.exception(
@@ -431,17 +499,23 @@ def create_pass_through_route(
     except Exception:
         verbose_proxy_logger.debug("Defaulting to target being a url.")
-    async def endpoint_func(
+    async def endpoint_func(  # type: ignore
         request: Request,
         fastapi_response: Response,
         user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+        query_params: Optional[dict] = None,
+        stream: Optional[
+            bool
+        ] = None,  # if pass-through endpoint is a streaming request
     ):
-        return await pass_through_request(
+        return await pass_through_request(  # type: ignore
             request=request,
             target=target,
             custom_headers=custom_headers or {},
             user_api_key_dict=user_api_key_dict,
             forward_headers=_forward_headers,
+            query_params=query_params,
+            stream=stream,
         )
     return endpoint_func

View file

@@ -230,6 +230,9 @@ from litellm.proxy.utils import (
     send_email,
     update_spend,
 )
+from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import (
+    router as gemini_router,
+)
 from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import router as vertex_router
 from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import set_default_vertex_config
 from litellm.router import (
@@ -9734,6 +9737,7 @@ def cleanup_router_config_variables():
 app.include_router(router)
 app.include_router(fine_tuning_router)
 app.include_router(vertex_router)
+app.include_router(gemini_router)
 app.include_router(pass_through_router)
 app.include_router(health_router)
 app.include_router(key_management_router)

View file

@@ -3,3 +3,94 @@ What is this?
Google AI Studio Pass-Through Endpoints
"""
"""
1. Create pass-through endpoints for any LITELLM_BASE_URL/gemini/<endpoint> map to https://generativelanguage.googleapis.com/<endpoint>
"""
import ast
import asyncio
import traceback
from datetime import datetime, timedelta, timezone
from typing import List, Optional
from urllib.parse import urlencode
import fastapi
import httpx
from fastapi import (
APIRouter,
Depends,
File,
Form,
Header,
HTTPException,
Request,
Response,
UploadFile,
status,
)
from starlette.datastructures import QueryParams
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.batches.main import FileObject
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
create_pass_through_route,
)
router = APIRouter()
default_vertex_config = None
@router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def gemini_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
    ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
    api_key = request.query_params.get("key")

    user_api_key_dict = await user_api_key_auth(
        request=request, api_key="Bearer {}".format(api_key)
    )

    base_target_url = "https://generativelanguage.googleapis.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Add or update query parameters
    gemini_api_key = litellm.utils.get_secret(secret_name="GEMINI_API_KEY")
    # Merge query parameters, giving precedence to those in updated_url
    merged_params = dict(request.query_params)
    merged_params.update({"key": gemini_api_key})

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        query_params=merged_params,
        stream=is_streaming_request,
    )

    return received_value

View file

@@ -1166,3 +1166,52 @@ async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
    assert new_data["success_callback"] == ["langfuse"]
    assert "langfuse_public_key" in new_data
    assert "langfuse_secret_key" in new_data


@pytest.mark.asyncio
async def test_gemini_pass_through_endpoint():
    from starlette.datastructures import URL

    from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import (
        Request,
        Response,
        gemini_proxy_route,
    )

    body = b"""
        {
            "contents": [{
                "parts":[{
                    "text": "The quick brown fox jumps over the lazy dog."
                }]
            }]
        }
        """

    # Construct the scope dictionary
    scope = {
        "type": "http",
        "method": "POST",
        "path": "/gemini/v1beta/models/gemini-1.5-flash:countTokens",
        "query_string": b"key=sk-1234",
        "headers": [
            (b"content-type", b"application/json"),
        ],
    }

    # Create a new Request object
    async def async_receive():
        return {"type": "http.request", "body": body, "more_body": False}

    request = Request(
        scope=scope,
        receive=async_receive,
    )

    resp = await gemini_proxy_route(
        endpoint="v1beta/models/gemini-1.5-flash:countTokens?key=sk-1234",
        request=request,
        fastapi_response=Response(),
    )

    print(resp.body)