(feat) Add support for using @google/generative-ai JS with LiteLLM Proxy (#6899)

* feat - allow using gemini js SDK with LiteLLM

* add auth for gemini_proxy_route

* basic local test for js

* test cost tagging gemini js requests

* add js sdk test for gemini with litellm

* add docs on gemini JS SDK

* run node.js tests

* fix google ai studio tests

* fix vertex js spend test
Ishaan Jaff 2024-11-25 13:13:03 -08:00 committed by GitHub
parent f77bf49772
commit c60261c3bc
8 changed files with 323 additions and 12 deletions

View file

@@ -1191,6 +1191,7 @@ jobs:
       -e DATABASE_URL=$PROXY_DATABASE_URL \
       -e LITELLM_MASTER_KEY="sk-1234" \
       -e OPENAI_API_KEY=$OPENAI_API_KEY \
+      -e GEMINI_API_KEY=$GEMINI_API_KEY \
       -e ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY \
       -e LITELLM_LICENSE=$LITELLM_LICENSE \
       --name my-app \
@@ -1228,12 +1229,13 @@ jobs:
       name: Install Node.js dependencies
       command: |
         npm install @google-cloud/vertexai
+        npm install @google/generative-ai
         npm install --save-dev jest
   - run:
-      name: Run Vertex AI tests
+      name: Run Vertex AI, Google AI Studio Node.js tests
       command: |
-        npx jest tests/pass_through_tests/test_vertex.test.js --verbose
+        npx jest tests/pass_through_tests --verbose
       no_output_timeout: 30m
   - run:
       name: Run tests

View file

@@ -1,12 +1,21 @@
+import Image from '@theme/IdealImage';
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
 # Google AI Studio SDK
 
 Pass-through endpoints for Google AI Studio - call provider-specific endpoint, in native format (no translation).
 
 Just replace `https://generativelanguage.googleapis.com` with `LITELLM_PROXY_BASE_URL/gemini` 🚀
 
 #### **Example Usage**
+
+<Tabs>
+<TabItem value="curl" label="curl">
+
 ```bash
-http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=sk-anything' \
+curl 'http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=sk-anything' \
 -H 'Content-Type: application/json' \
 -d '{
 "contents": [{
@@ -17,6 +26,53 @@ http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=sk-any
 }'
 ```
+
+</TabItem>
+<TabItem value="js" label="Google AI Node.js SDK">
+
+```javascript
+const { GoogleGenerativeAI } = require("@google/generative-ai");
+
+const modelParams = {
+  model: 'gemini-pro',
+};
+
+const requestOptions = {
+  baseUrl: 'http://localhost:4000/gemini', // http://<proxy-base-url>/gemini
+};
+
+const genAI = new GoogleGenerativeAI("sk-1234"); // litellm proxy API key
+const model = genAI.getGenerativeModel(modelParams, requestOptions);
+
+async function main() {
+  try {
+    const result = await model.generateContent("Explain how AI works");
+    console.log(result.response.text());
+  } catch (error) {
+    console.error('Error:', error);
+  }
+}
+
+// For streaming responses
+async function main_streaming() {
+  try {
+    const streamingResult = await model.generateContentStream("Explain how AI works");
+    for await (const chunk of streamingResult.stream) {
+      console.log('Stream chunk:', JSON.stringify(chunk));
+    }
+    const aggregatedResponse = await streamingResult.response;
+    console.log('Aggregated response:', JSON.stringify(aggregatedResponse));
+  } catch (error) {
+    console.error('Error:', error);
+  }
+}
+
+main();
+// main_streaming();
+```
+
+</TabItem>
+</Tabs>
 
 Supports **ALL** Google AI Studio Endpoints (including streaming).
 
 [**See All Google AI Studio Endpoints**](https://ai.google.dev/api)
@@ -166,14 +222,14 @@ curl -X POST "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5
 ```
 
-## Advanced - Use with Virtual Keys
+## Advanced
 
 Pre-requisites
 - [Setup proxy with DB](../proxy/virtual_keys.md#setup)
 
 Use this, to avoid giving developers the raw Google AI Studio key, but still letting them use Google AI Studio endpoints.
 
-### Usage
+### Use with Virtual Keys
 
 1. Setup environment
@@ -220,4 +276,66 @@ http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=sk-123
 }]
 }]
 }'
 ```
+
+### Send `tags` in request headers
+
+Use this if you want `tags` to be tracked in the LiteLLM DB and on logging callbacks.
+
+Pass tags in request headers as a comma separated list. In the example below the following tags will be tracked
+
+```
+tags: ["gemini-js-sdk", "pass-through-endpoint"]
+```
+
+<Tabs>
+<TabItem value="curl" label="curl">
+
+```bash
+curl 'http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:generateContent?key=sk-anything' \
+-H 'Content-Type: application/json' \
+-H 'tags: gemini-js-sdk,pass-through-endpoint' \
+-d '{
+"contents": [{
+"parts":[{
+"text": "The quick brown fox jumps over the lazy dog."
+}]
+}]
+}'
+```
+
+</TabItem>
+<TabItem value="js" label="Google AI Node.js SDK">
+
+```javascript
+const { GoogleGenerativeAI } = require("@google/generative-ai");
+
+const modelParams = {
+  model: 'gemini-pro',
+};
+
+const requestOptions = {
+  baseUrl: 'http://localhost:4000/gemini', // http://<proxy-base-url>/gemini
+  customHeaders: {
+    "tags": "gemini-js-sdk,pass-through-endpoint"
+  }
+};
+
+const genAI = new GoogleGenerativeAI("sk-1234"); // litellm proxy API key
+const model = genAI.getGenerativeModel(modelParams, requestOptions);
+
+async function main() {
+  try {
+    const result = await model.generateContent("Explain how AI works");
+    console.log(result.response.text());
+  } catch (error) {
+    console.error('Error:', error);
+  }
+}
+
+main();
+```
+
+</TabItem>
+</Tabs>
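
Note: the tags above end up in LiteLLM's spend logs. The tests added later in this commit correlate a request with its log entry through the `x-litellm-call-id` response header; here is a minimal sketch of that flow with plain `fetch` (not part of the commit), assuming a local proxy at `http://0.0.0.0:4000` and the placeholder keys `sk-anything` / `sk-1234`:

```javascript
// Sketch (not part of the commit): send a tagged request, capture the
// x-litellm-call-id response header, then look the request up in spend logs.
// Proxy URL and keys ("sk-anything", "sk-1234") are illustrative placeholders.
async function generateAndCheckSpend() {
  const res = await fetch(
    "http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:generateContent?key=sk-anything",
    {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "tags": "gemini-js-sdk,pass-through-endpoint",
      },
      body: JSON.stringify({
        contents: [{ parts: [{ text: "The quick brown fox jumps over the lazy dog." }] }],
      }),
    }
  );
  const callId = res.headers.get("x-litellm-call-id");

  // Spend is logged asynchronously; the commit's tests wait ~15s before checking.
  const logs = await fetch(`http://0.0.0.0:4000/spend/logs?request_id=${callId}`, {
    headers: { Authorization: "Bearer sk-1234" },
  });
  console.log(await logs.json()); // entries include request_tags, model, spend
}

generateAndCheckSpend();
```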

View file

@@ -2111,6 +2111,7 @@ class SpecialHeaders(enum.Enum):
     openai_authorization = "Authorization"
     azure_authorization = "API-Key"
     anthropic_authorization = "x-api-key"
+    google_ai_studio_authorization = "x-goog-api-key"
 
 class LitellmDataForBackendLLMCall(TypedDict, total=False):

View file

@@ -95,6 +95,11 @@ anthropic_api_key_header = APIKeyHeader(
     auto_error=False,
     description="If anthropic client used.",
 )
+google_ai_studio_api_key_header = APIKeyHeader(
+    name=SpecialHeaders.google_ai_studio_authorization.value,
+    auto_error=False,
+    description="If google ai studio client used.",
+)
 
 def _get_bearer_token(
@@ -197,6 +202,9 @@ async def user_api_key_auth(  # noqa: PLR0915
     anthropic_api_key_header: Optional[str] = fastapi.Security(
         anthropic_api_key_header
     ),
+    google_ai_studio_api_key_header: Optional[str] = fastapi.Security(
+        google_ai_studio_api_key_header
+    ),
 ) -> UserAPIKeyAuth:
     from litellm.proxy.proxy_server import (
         general_settings,
@@ -233,6 +241,8 @@ async def user_api_key_auth(  # noqa: PLR0915
         api_key = azure_api_key_header
     elif isinstance(anthropic_api_key_header, str):
         api_key = anthropic_api_key_header
+    elif isinstance(google_ai_studio_api_key_header, str):
+        api_key = google_ai_studio_api_key_header
     elif pass_through_endpoints is not None:
         for endpoint in pass_through_endpoints:
             if endpoint.get("path", "") == route:

View file

@@ -61,10 +61,12 @@ async def gemini_proxy_route(
     fastapi_response: Response,
 ):
     ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
-    api_key = request.query_params.get("key")
+    google_ai_studio_api_key = request.query_params.get("key") or request.headers.get(
+        "x-goog-api-key"
+    )
 
     user_api_key_dict = await user_api_key_auth(
-        request=request, api_key="Bearer {}".format(api_key)
+        request=request, api_key=f"Bearer {google_ai_studio_api_key}"
     )
 
     base_target_url = "https://generativelanguage.googleapis.com"
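
Client-side, either credential style now reaches the same auth check: the `?key=` query param used in the curl examples, or the `x-goog-api-key` header that the Google SDK sends. A sketch of the two equivalent calls (not part of the commit; URL and key are placeholders):

```javascript
// Sketch (not part of the commit): the two ways a client can now pass its
// LiteLLM key to the /gemini pass-through route. URL and key are placeholders.
const base = "http://localhost:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens";
const body = JSON.stringify({ contents: [{ parts: [{ text: "hello" }] }] });

async function demo() {
  // 1) LiteLLM key in the query string (checked first by gemini_proxy_route)
  await fetch(`${base}?key=sk-1234`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body,
  });

  // 2) LiteLLM key in the x-goog-api-key header (the fallback, and what
  //    @google/generative-ai sends on behalf of its callers)
  await fetch(base, {
    method: "POST",
    headers: { "Content-Type": "application/json", "x-goog-api-key": "sk-1234" },
    body,
  });
}

demo();
```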

View file

@@ -0,0 +1,123 @@
const { GoogleGenerativeAI } = require("@google/generative-ai");
const fs = require('fs');
const path = require('path');

// Import fetch if the SDK uses it
const originalFetch = global.fetch || require('node-fetch');

let lastCallId;

// Monkey-patch the fetch used internally
global.fetch = async function patchedFetch(url, options) {
  const response = await originalFetch(url, options);
  // Store the call ID if it exists
  lastCallId = response.headers.get('x-litellm-call-id');
  return response;
};

describe('Gemini AI Tests', () => {
  test('should successfully generate non-streaming content with tags', async () => {
    const genAI = new GoogleGenerativeAI("sk-1234"); // litellm proxy API key

    const requestOptions = {
      baseUrl: 'http://127.0.0.1:4000/gemini',
      customHeaders: {
        "tags": "gemini-js-sdk,pass-through-endpoint"
      }
    };

    const model = genAI.getGenerativeModel({
      model: 'gemini-pro'
    }, requestOptions);

    const prompt = 'Say "hello test" and nothing else';

    const result = await model.generateContent(prompt);
    expect(result).toBeDefined();

    // Use the captured callId
    const callId = lastCallId;
    console.log("Captured Call ID:", callId);

    // Wait for spend to be logged
    await new Promise(resolve => setTimeout(resolve, 15000));

    // Check spend logs
    const spendResponse = await fetch(
      `http://127.0.0.1:4000/spend/logs?request_id=${callId}`,
      {
        headers: {
          'Authorization': 'Bearer sk-1234'
        }
      }
    );

    const spendData = await spendResponse.json();
    console.log("spendData", spendData)
    expect(spendData).toBeDefined();
    expect(spendData[0].request_id).toBe(callId);
    expect(spendData[0].call_type).toBe('pass_through_endpoint');
    expect(spendData[0].request_tags).toEqual(['gemini-js-sdk', 'pass-through-endpoint']);
    expect(spendData[0].metadata).toHaveProperty('user_api_key');
    expect(spendData[0].model).toContain('gemini');
    expect(spendData[0].spend).toBeGreaterThan(0);
  }, 25000);

  test('should successfully generate streaming content with tags', async () => {
    const genAI = new GoogleGenerativeAI("sk-1234"); // litellm proxy API key

    const requestOptions = {
      baseUrl: 'http://127.0.0.1:4000/gemini',
      customHeaders: {
        "tags": "gemini-js-sdk,pass-through-endpoint"
      }
    };

    const model = genAI.getGenerativeModel({
      model: 'gemini-pro'
    }, requestOptions);

    const prompt = 'Say "hello test" and nothing else';

    const streamingResult = await model.generateContentStream(prompt);
    expect(streamingResult).toBeDefined();

    for await (const chunk of streamingResult.stream) {
      console.log('stream chunk:', JSON.stringify(chunk));
      expect(chunk).toBeDefined();
    }

    const aggregatedResponse = await streamingResult.response;
    console.log('aggregated response:', JSON.stringify(aggregatedResponse));
    expect(aggregatedResponse).toBeDefined();

    // Use the captured callId
    const callId = lastCallId;
    console.log("Captured Call ID:", callId);

    // Wait for spend to be logged
    await new Promise(resolve => setTimeout(resolve, 15000));

    // Check spend logs
    const spendResponse = await fetch(
      `http://127.0.0.1:4000/spend/logs?request_id=${callId}`,
      {
        headers: {
          'Authorization': 'Bearer sk-1234'
        }
      }
    );

    const spendData = await spendResponse.json();
    console.log("spendData", spendData)
    expect(spendData).toBeDefined();
    expect(spendData[0].request_id).toBe(callId);
    expect(spendData[0].call_type).toBe('pass_through_endpoint');
    expect(spendData[0].request_tags).toEqual(['gemini-js-sdk', 'pass-through-endpoint']);
    expect(spendData[0].metadata).toHaveProperty('user_api_key');
    expect(spendData[0].model).toContain('gemini');
    expect(spendData[0].spend).toBeGreaterThan(0);
  }, 25000);
});
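
One note on the `global.fetch` monkey-patch above: it persists for the rest of the process, so the spend-log lookups in these tests also flow through it. If other suites in the same Jest run should see the unpatched fetch, an `afterAll` hook can restore it (a sketch, not part of the commit):

```javascript
// Sketch (not part of the commit): restore the unpatched fetch after this
// suite, so the x-litellm-call-id capture doesn't leak into other test files.
afterAll(() => {
  global.fetch = originalFetch;
});
```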

View file

@@ -0,0 +1,55 @@
const { GoogleGenerativeAI, ModelParams, RequestOptions } = require("@google/generative-ai");

const modelParams = {
  model: 'gemini-pro',
};

const requestOptions = {
  baseUrl: 'http://127.0.0.1:4000/gemini',
  customHeaders: {
    "tags": "gemini-js-sdk,gemini-pro"
  }
};

const genAI = new GoogleGenerativeAI("sk-1234"); // litellm proxy API key
const model = genAI.getGenerativeModel(modelParams, requestOptions);

const testPrompt = "Explain how AI works";

async function main() {
  console.log("making request")
  try {
    const result = await model.generateContent(testPrompt);
    console.log(result.response.text());
  } catch (error) {
    console.error('Error details:', {
      name: error.name,
      message: error.message,
      cause: error.cause,
      // Check if there's a network error
      isNetworkError: error instanceof TypeError && error.message === 'fetch failed'
    });

    // Check if the server is running
    if (error instanceof TypeError && error.message === 'fetch failed') {
      console.error('Make sure your local server is running at http://localhost:4000');
    }
  }
}

async function main_streaming() {
  try {
    const streamingResult = await model.generateContentStream(testPrompt);
    for await (const item of streamingResult.stream) {
      console.log('stream chunk: ', JSON.stringify(item));
    }
    const aggregatedResponse = await streamingResult.response;
    console.log('aggregated response: ', JSON.stringify(aggregatedResponse));
  } catch (error) {
    console.error('Error details:', error);
  }
}

// main();
main_streaming();

View file

@@ -60,9 +60,9 @@ function loadVertexAiCredentials() {
 }
 
 // Run credential loading before tests
-// beforeAll(() => {
-//     loadVertexAiCredentials();
-// });
+beforeAll(() => {
+  loadVertexAiCredentials();
+});