Merge pull request #5260 from BerriAI/google_ai_studio_pass_through

Pass-through endpoints for Gemini - Google AI Studio

Commit ff6ff133ee
9 changed files with 479 additions and 31 deletions

docs/my-website/docs/pass_through/google_ai_studio.md (new file, +223 lines)
# Google AI Studio (Pass-Through)

Pass-through endpoints for Google AI Studio - call provider-specific endpoints in their native format (no translation).

Just replace `https://generativelanguage.googleapis.com` with `LITELLM_PROXY_BASE_URL/gemini` 🚀

#### **Example Usage**
```bash
curl 'http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=sk-anything' \
  -H 'Content-Type: application/json' \
  -d '{
    "contents": [{
      "parts":[{
        "text": "The quick brown fox jumps over the lazy dog."
      }]
    }]
  }'
```
Supports **ALL** Google AI Studio Endpoints (including streaming).

[**See All Google AI Studio Endpoints**](https://ai.google.dev/api)
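
For example, streaming responses flow through the same route. A minimal sketch (the `:streamGenerateContent` method and the `alt=sse` query parameter come from the Google AI Studio API; `key=anything` assumes no virtual keys are configured):

```bash
curl 'http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:streamGenerateContent?alt=sse&key=anything' \
  -H 'Content-Type: application/json' \
  -d '{
    "contents": [{
      "parts":[{
        "text": "Write a haiku about proxies."
      }]
    }]
  }'
```

The proxy treats this as a streaming request because the target URL contains `stream` (see the `gemini_proxy_route` handler added in this PR).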

## Quick Start

Let's call the Gemini [`/countTokens` endpoint](https://ai.google.dev/api/tokens#method:-models.counttokens)

1. Add Gemini API Key to your environment

```bash
export GEMINI_API_KEY=""
```

2. Start LiteLLM Proxy

```bash
litellm

# RUNNING on http://0.0.0.0:4000
```

3. Test it!

Let's call the Google AI Studio token counting endpoint
```bash
curl 'http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=anything' \
  -H 'Content-Type: application/json' \
  -d '{
    "contents": [{
      "parts":[{
        "text": "The quick brown fox jumps over the lazy dog."
      }]
    }]
  }'
```
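
Expected response (a sketch; `totalTokens` is the field returned by the Google AI Studio `countTokens` API, and the exact value depends on the model's tokenizer):

```bash
{
    "totalTokens": 10
}
```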

## Examples

Anything after `http://0.0.0.0:4000/gemini` is treated as a provider-specific route and handled accordingly; any sub-path is forwarded as-is (see the sketch after the table below).

Key Changes:

| **Original Endpoint**                                | **Replace With**                  |
|------------------------------------------------------|-----------------------------------|
| `https://generativelanguage.googleapis.com`          | `http://0.0.0.0:4000/gemini` (`LITELLM_PROXY_BASE_URL="http://0.0.0.0:4000"`) |
| `key=$GOOGLE_API_KEY`                                | `key=anything` (use `key=LITELLM_VIRTUAL_KEY` if Virtual Keys are set up on the proxy) |
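
For instance, listing models works with no extra configuration (a minimal sketch; `v1beta/models` is the standard Google AI Studio model-listing route):

```bash
curl 'http://0.0.0.0:4000/gemini/v1beta/models?key=anything'
```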

### **Example 1: Counting tokens**

#### LiteLLM Proxy Call

```bash
curl 'http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=anything' \
  -H 'Content-Type: application/json' \
  -X POST \
  -d '{
    "contents": [{
      "parts":[{
        "text": "The quick brown fox jumps over the lazy dog."
      }]
    }]
  }'
```

#### Direct Google AI Studio Call

```bash
curl "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:countTokens?key=$GOOGLE_API_KEY" \
  -H 'Content-Type: application/json' \
  -X POST \
  -d '{
    "contents": [{
      "parts":[{
        "text": "The quick brown fox jumps over the lazy dog."
      }]
    }]
  }'
```

### **Example 2: Generate content**

#### LiteLLM Proxy Call

```bash
curl "http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:generateContent?key=anything" \
  -H 'Content-Type: application/json' \
  -X POST \
  -d '{
    "contents": [{
      "parts":[{"text": "Write a story about a magic backpack."}]
    }]
  }' 2> /dev/null
```

#### Direct Google AI Studio Call

```bash
curl "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=$GOOGLE_API_KEY" \
  -H 'Content-Type: application/json' \
  -X POST \
  -d '{
    "contents": [{
      "parts":[{"text": "Write a story about a magic backpack."}]
    }]
  }' 2> /dev/null
```
### **Example 3: Caching**
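
The `$CACHE_NAME` used below refers to an existing cached-content resource. As a minimal sketch of creating one through the proxy first (the `cachedContents` route and the `model`, `contents`, and `ttl` fields come from the Google AI Studio caching API; the `jq` extraction and the transcript placeholder are illustrative):

```bash
CACHE_NAME=$(curl -s -X POST "http://0.0.0.0:4000/gemini/v1beta/cachedContents?key=anything" \
  -H 'Content-Type: application/json' \
  -d '{
    "model": "models/gemini-1.5-flash-001",
    "contents": [{
      "parts":[{"text": "<transcript long enough to meet the minimum cacheable token count>"}],
      "role": "user"
    }],
    "ttl": "300s"
  }' | jq -r '.name')
```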

#### LiteLLM Proxy Call

```bash
curl -X POST "http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash-001:generateContent?key=anything" \
  -H 'Content-Type: application/json' \
  -d '{
    "contents": [
      {
        "parts":[{
          "text": "Please summarize this transcript"
        }],
        "role": "user"
      }
    ],
    "cachedContent": "'$CACHE_NAME'"
  }'
```
#### Direct Google AI Studio Call

```bash
curl -X POST "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-001:generateContent?key=$GOOGLE_API_KEY" \
  -H 'Content-Type: application/json' \
  -d '{
    "contents": [
      {
        "parts":[{
          "text": "Please summarize this transcript"
        }],
        "role": "user"
      }
    ],
    "cachedContent": "'$CACHE_NAME'"
  }'
```
## Advanced - Use with Virtual Keys

Pre-requisites
- [Setup proxy with DB](../proxy/virtual_keys.md#setup)

Use this to avoid giving developers the raw Google AI Studio key while still letting them call Google AI Studio endpoints.

### Usage
1. Setup environment

```bash
export DATABASE_URL=""
export LITELLM_MASTER_KEY=""
export GEMINI_API_KEY=""
```

```bash
litellm

# RUNNING on http://0.0.0.0:4000
```
2. Generate virtual key

```bash
curl -X POST 'http://0.0.0.0:4000/key/generate' \
  -H 'Authorization: Bearer sk-1234' \
  -H 'Content-Type: application/json' \
  -d '{}'
```

Expected Response

```bash
{
    ...
    "key": "sk-1234ewknldferwedojwojw"
}
```
3. Test it!

```bash
curl 'http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=sk-1234ewknldferwedojwojw' \
  -H 'Content-Type: application/json' \
  -d '{
    "contents": [{
      "parts":[{
        "text": "The quick brown fox jumps over the lazy dog."
      }]
    }]
  }'
```
@@ -192,7 +192,8 @@ const sidebars = {
         "batches",
         "fine_tuning",
         "anthropic_completion",
-        "vertex_ai"
+        "pass_through/vertex_ai",
+        "pass_through/google_ai_studio"
       ],
     },
     "scheduler",
@@ -412,7 +412,7 @@ def get_replicate_completion_pricing(completion_response=None, total_time=0.0):


 def _select_model_name_for_cost_calc(
     model: Optional[str],
-    completion_response: Union[BaseModel, dict],
+    completion_response: Union[BaseModel, dict, str],
     base_model: Optional[str] = None,
     custom_pricing: Optional[bool] = None,
 ) -> Optional[str]:
@@ -428,7 +428,12 @@ def _select_model_name_for_cost_calc(
     if base_model is not None:
         return base_model

-    return_model = model or completion_response.get("model", "")  # type: ignore
+    return_model = model
+    if isinstance(completion_response, str):
+        return return_model
+
+    elif return_model is None:
+        return_model = completion_response.get("model", "")  # type: ignore
     if hasattr(completion_response, "_hidden_params"):
         if (
             completion_response._hidden_params.get("model", None) is not None
@@ -274,6 +274,7 @@ class Logging:
                 headers = {}
             data = additional_args.get("complete_input_dict", {})
             api_base = str(additional_args.get("api_base", ""))
+            query_params = additional_args.get("query_params", {})
             if "key=" in api_base:
                 # Find the position of "key=" in the string
                 key_index = api_base.find("key=") + 4
@@ -2362,7 +2363,7 @@ def get_standard_logging_object_payload(

         return payload
     except Exception as e:
-        verbose_logger.error(
+        verbose_logger.exception(
             "Error creating standard logging object - {}".format(str(e))
         )
         return None
@@ -3,7 +3,7 @@ import asyncio
 import json
 import traceback
 from base64 import b64encode
-from typing import List, Optional
+from typing import AsyncIterable, List, Optional

 import httpx
 from fastapi import (
@@ -267,12 +267,25 @@ def forward_headers_from_request(
     return headers


+def get_response_headers(headers: httpx.Headers) -> dict:
+    excluded_headers = {"transfer-encoding", "content-encoding"}
+    return_headers = {
+        key: value
+        for key, value in headers.items()
+        if key.lower() not in excluded_headers
+    }
+
+    return return_headers
+
+
 async def pass_through_request(
     request: Request,
     target: str,
     custom_headers: dict,
     user_api_key_dict: UserAPIKeyAuth,
     forward_headers: Optional[bool] = False,
+    query_params: Optional[dict] = None,
+    stream: Optional[bool] = None,
 ):
     try:
         import time
@@ -291,7 +304,7 @@ async def pass_through_request(
         body_str = request_body.decode()
         try:
             _parsed_body = ast.literal_eval(body_str)
-        except:
+        except Exception:
             _parsed_body = json.loads(body_str)

         verbose_proxy_logger.debug(
@@ -308,23 +321,9 @@ async def pass_through_request(
         )

         async_client = httpx.AsyncClient()
-        response = await async_client.request(
-            method=request.method,
-            url=url,
-            headers=headers,
-            params=request.query_params,
-            json=_parsed_body,
-        )
-
-        if response.status_code >= 300:
-            raise HTTPException(status_code=response.status_code, detail=response.text)
-
-        content = await response.aread()
-
-        ## LOG SUCCESS
-        start_time = time.time()
-        end_time = time.time()
         # create logging object
+        start_time = time.time()
         logging_obj = Logging(
             model="unknown",
             messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
@@ -334,6 +333,7 @@ async def pass_through_request(
             litellm_call_id=str(uuid.uuid4()),
             function_id="1245",
         )
+
         # done for supporting 'parallel_request_limiter.py' with pass-through endpoints
         kwargs = {
             "litellm_params": {
@@ -354,6 +354,81 @@ async def pass_through_request(
             call_type="pass_through_endpoint",
         )

+        # combine url with query params for logging
+
+        requested_query_params = query_params or request.query_params.__dict__
+        requested_query_params_str = "&".join(
+            f"{k}={v}" for k, v in requested_query_params.items()
+        )
+
+        if "?" in str(url):
+            logging_url = str(url) + "&" + requested_query_params_str
+        else:
+            logging_url = str(url) + "?" + requested_query_params_str
+
+        logging_obj.pre_call(
+            input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
+            api_key="",
+            additional_args={
+                "complete_input_dict": _parsed_body,
+                "api_base": logging_url,
+                "headers": headers,
+            },
+        )
+
+        if stream:
+            req = async_client.build_request(
+                "POST",
+                url,
+                json=_parsed_body,
+                params=requested_query_params,
+                headers=headers,
+            )
+
+            response = await async_client.send(req, stream=stream)
+
+            # Create an async generator to yield the response content
+            async def stream_response() -> AsyncIterable[bytes]:
+                async for chunk in response.aiter_bytes():
+                    yield chunk
+
+            return StreamingResponse(
+                stream_response(),
+                headers=get_response_headers(response.headers),
+                status_code=response.status_code,
+            )
+
+        response = await async_client.request(
+            method=request.method,
+            url=url,
+            headers=headers,
+            params=requested_query_params,
+            json=_parsed_body,
+        )
+
+        if (
+            response.headers.get("content-type") is not None
+            and response.headers["content-type"] == "text/event-stream"
+        ):
+            # streaming response
+            # Create an async generator to yield the response content
+            async def stream_response() -> AsyncIterable[bytes]:
+                async for chunk in response.aiter_bytes():
+                    yield chunk
+
+            return StreamingResponse(
+                stream_response(),
+                headers=get_response_headers(response.headers),
+                status_code=response.status_code,
+            )
+
+        if response.status_code >= 300:
+            raise HTTPException(status_code=response.status_code, detail=response.text)
+
+        content = await response.aread()
+
+        ## LOG SUCCESS
+        end_time = time.time()
+
         await logging_obj.async_success_handler(
             result="",
             start_time=start_time,
@@ -361,17 +436,10 @@ async def pass_through_request(
             cache_hit=False,
         )

-        excluded_headers = {"transfer-encoding", "content-encoding"}
-        headers = {
-            key: value
-            for key, value in response.headers.items()
-            if key.lower() not in excluded_headers
-        }
-
         return Response(
             content=content,
             status_code=response.status_code,
-            headers=headers,
+            headers=get_response_headers(response.headers),
         )
     except Exception as e:
         verbose_proxy_logger.exception(
@@ -431,17 +499,23 @@ def create_pass_through_route(
     except Exception:
         verbose_proxy_logger.debug("Defaulting to target being a url.")

-    async def endpoint_func(
+    async def endpoint_func(  # type: ignore
         request: Request,
         fastapi_response: Response,
         user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+        query_params: Optional[dict] = None,
+        stream: Optional[
+            bool
+        ] = None,  # if pass-through endpoint is a streaming request
     ):
-        return await pass_through_request(
+        return await pass_through_request(  # type: ignore
             request=request,
             target=target,
             custom_headers=custom_headers or {},
             user_api_key_dict=user_api_key_dict,
             forward_headers=_forward_headers,
+            query_params=query_params,
+            stream=stream,
         )

     return endpoint_func
@@ -230,6 +230,9 @@ from litellm.proxy.utils import (
     send_email,
     update_spend,
 )
+from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import (
+    router as gemini_router,
+)
 from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import router as vertex_router
 from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import set_default_vertex_config
 from litellm.router import (
@@ -9734,6 +9737,7 @@ def cleanup_router_config_variables():
 app.include_router(router)
 app.include_router(fine_tuning_router)
 app.include_router(vertex_router)
+app.include_router(gemini_router)
 app.include_router(pass_through_router)
 app.include_router(health_router)
 app.include_router(key_management_router)
@@ -3,3 +3,94 @@ What is this?

 Google AI Studio Pass-Through Endpoints
 """
+
+"""
+1. Create pass-through endpoints mapping LITELLM_BASE_URL/gemini/<endpoint> to https://generativelanguage.googleapis.com/<endpoint>
+"""
+
+import ast
+import asyncio
+import traceback
+from datetime import datetime, timedelta, timezone
+from typing import List, Optional
+from urllib.parse import urlencode
+
+import fastapi
+import httpx
+from fastapi import (
+    APIRouter,
+    Depends,
+    File,
+    Form,
+    Header,
+    HTTPException,
+    Request,
+    Response,
+    UploadFile,
+    status,
+)
+from starlette.datastructures import QueryParams
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.batches.main import FileObject
+from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
+from litellm.proxy._types import *
+from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
+    create_pass_through_route,
+)
+
+router = APIRouter()
+default_vertex_config = None
+
+
+@router.api_route("/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def gemini_proxy_route(
+    endpoint: str,
+    request: Request,
+    fastapi_response: Response,
+):
+    ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
+    api_key = request.query_params.get("key")
+
+    user_api_key_dict = await user_api_key_auth(
+        request=request, api_key="Bearer {}".format(api_key)
+    )
+
+    base_target_url = "https://generativelanguage.googleapis.com"
+    encoded_endpoint = httpx.URL(endpoint).path
+
+    # Ensure endpoint starts with '/' for proper URL construction
+    if not encoded_endpoint.startswith("/"):
+        encoded_endpoint = "/" + encoded_endpoint
+
+    # Construct the full target URL using httpx
+    base_url = httpx.URL(base_target_url)
+    updated_url = base_url.copy_with(path=encoded_endpoint)
+
+    # Add or update query parameters
+    gemini_api_key = litellm.utils.get_secret(secret_name="GEMINI_API_KEY")
+    # Merge query parameters, giving precedence to those in updated_url
+    merged_params = dict(request.query_params)
+    merged_params.update({"key": gemini_api_key})
+
+    ## check for streaming
+    is_streaming_request = False
+    if "stream" in str(updated_url):
+        is_streaming_request = True
+
+    ## CREATE PASS-THROUGH
+    endpoint_func = create_pass_through_route(
+        endpoint=endpoint,
+        target=str(updated_url),
+    )  # dynamically construct pass-through endpoint based on incoming path
+    received_value = await endpoint_func(
+        request,
+        fastapi_response,
+        user_api_key_dict,
+        query_params=merged_params,
+        stream=is_streaming_request,
+    )
+
+    return received_value
@@ -1166,3 +1166,52 @@ async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
     assert new_data["success_callback"] == ["langfuse"]
     assert "langfuse_public_key" in new_data
     assert "langfuse_secret_key" in new_data
+
+
+@pytest.mark.asyncio
+async def test_gemini_pass_through_endpoint():
+    from starlette.datastructures import URL
+
+    from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import (
+        Request,
+        Response,
+        gemini_proxy_route,
+    )
+
+    body = b"""
+        {
+            "contents": [{
+                "parts":[{
+                    "text": "The quick brown fox jumps over the lazy dog."
+                }]
+            }]
+        }
+        """
+
+    # Construct the scope dictionary
+    scope = {
+        "type": "http",
+        "method": "POST",
+        "path": "/gemini/v1beta/models/gemini-1.5-flash:countTokens",
+        "query_string": b"key=sk-1234",
+        "headers": [
+            (b"content-type", b"application/json"),
+        ],
+    }
+
+    # Create a new Request object
+    async def async_receive():
+        return {"type": "http.request", "body": body, "more_body": False}
+
+    request = Request(
+        scope=scope,
+        receive=async_receive,
+    )
+
+    resp = await gemini_proxy_route(
+        endpoint="v1beta/models/gemini-1.5-flash:countTokens?key=sk-1234",
+        request=request,
+        fastapi_response=Response(),
+    )
+
+    print(resp.body)