From c60261c3bc30f99bb01a1f8776de9d254c7efa1a Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 13:13:03 -0800 Subject: [PATCH] (feat) Add support for using @google/generative-ai JS with LiteLLM Proxy (#6899) * feat - allow using gemini js SDK with LiteLLM * add auth for gemini_proxy_route * basic local test for js * test cost tagging gemini js requests * add js sdk test for gemini with litellm * add docs on gemini JS SDK * run node.js tests * fix google ai studio tests * fix vertex js spend test --- .circleci/config.yml | 6 +- .../docs/pass_through/google_ai_studio.md | 128 +++++++++++++++++- litellm/proxy/_types.py | 1 + litellm/proxy/auth/user_api_key_auth.py | 10 ++ .../llm_passthrough_endpoints.py | 6 +- .../test_gemini_with_spend.test.js | 123 +++++++++++++++++ tests/pass_through_tests/test_local_gemini.js | 55 ++++++++ .../test_vertex_with_spend.test.js | 6 +- 8 files changed, 323 insertions(+), 12 deletions(-) create mode 100644 tests/pass_through_tests/test_gemini_with_spend.test.js create mode 100644 tests/pass_through_tests/test_local_gemini.js diff --git a/.circleci/config.yml b/.circleci/config.yml index d33f62cf3..56fb1fee1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1191,6 +1191,7 @@ jobs: -e DATABASE_URL=$PROXY_DATABASE_URL \ -e LITELLM_MASTER_KEY="sk-1234" \ -e OPENAI_API_KEY=$OPENAI_API_KEY \ + -e GEMINI_API_KEY=$GEMINI_API_KEY \ -e ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY \ -e LITELLM_LICENSE=$LITELLM_LICENSE \ --name my-app \ @@ -1228,12 +1229,13 @@ jobs: name: Install Node.js dependencies command: | npm install @google-cloud/vertexai + npm install @google/generative-ai npm install --save-dev jest - run: - name: Run Vertex AI tests + name: Run Vertex AI, Google AI Studio Node.js tests command: | - npx jest tests/pass_through_tests/test_vertex.test.js --verbose + npx jest tests/pass_through_tests --verbose no_output_timeout: 30m - run: name: Run tests diff --git a/docs/my-website/docs/pass_through/google_ai_studio.md b/docs/my-website/docs/pass_through/google_ai_studio.md index cc7f9ce71..ee5eecc19 100644 --- a/docs/my-website/docs/pass_through/google_ai_studio.md +++ b/docs/my-website/docs/pass_through/google_ai_studio.md @@ -1,12 +1,21 @@ +import Image from '@theme/IdealImage'; +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + + # Google AI Studio SDK Pass-through endpoints for Google AI Studio - call provider-specific endpoint, in native format (no translation). 
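 
+You can authenticate with your LiteLLM proxy key either as a `?key=` query param or via the `x-goog-api-key` header (the header the Google SDKs send). A minimal sketch of header-based auth, assuming a proxy at `0.0.0.0:4000` and `sk-anything` as a placeholder LiteLLM key:
+
+```bash
+curl 'http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens' \
+-H 'Content-Type: application/json' \
+-H 'x-goog-api-key: sk-anything' \
+-d '{"contents": [{"parts":[{"text": "Hello"}]}]}'
+```
+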
-Just replace `https://generativelanguage.googleapis.com` with `LITELLM_PROXY_BASE_URL/gemini` 🚀
+Just replace `https://generativelanguage.googleapis.com` with `LITELLM_PROXY_BASE_URL/gemini`
 
 
 #### **Example Usage**
 
+<Tabs>
+<TabItem value="curl" label="curl">
+
 ```bash
-http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=sk-anything' \
+curl 'http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=sk-anything' \
 -H 'Content-Type: application/json' \
 -d '{
   "contents": [{
     "parts":[{
@@ -17,6 +26,53 @@ http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=sk-any
       "text": "The quick brown fox jumps over the lazy dog."
     }]
   }]
 }'
 ```
 
+</TabItem>
+<TabItem value="js" label="Google AI Node.js SDK">
+
+```javascript
+const { GoogleGenerativeAI } = require("@google/generative-ai");
+
+const modelParams = {
+  model: 'gemini-pro',
+};
+
+const requestOptions = {
+  baseUrl: 'http://localhost:4000/gemini', // http://<your-proxy-base-url>/gemini
+};
+
+const genAI = new GoogleGenerativeAI("sk-1234"); // litellm proxy API key
+const model = genAI.getGenerativeModel(modelParams, requestOptions);
+
+async function main() {
+  try {
+    const result = await model.generateContent("Explain how AI works");
+    console.log(result.response.text());
+  } catch (error) {
+    console.error('Error:', error);
+  }
+}
+
+// For streaming responses
+async function main_streaming() {
+  try {
+    const streamingResult = await model.generateContentStream("Explain how AI works");
+    for await (const chunk of streamingResult.stream) {
+      console.log('Stream chunk:', JSON.stringify(chunk));
+    }
+    const aggregatedResponse = await streamingResult.response;
+    console.log('Aggregated response:', JSON.stringify(aggregatedResponse));
+  } catch (error) {
+    console.error('Error:', error);
+  }
+}
+
+main();
+// main_streaming();
+```
+
+</TabItem>
+</Tabs>
+
 Supports **ALL** Google AI Studio Endpoints (including streaming).
 
 [**See All Google AI Studio Endpoints**](https://ai.google.dev/api)
 
@@ -166,14 +222,14 @@ curl -X POST "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5
 ```
 
 
-## Advanced - Use with Virtual Keys
+## Advanced
 
 Pre-requisites
 - [Setup proxy with DB](../proxy/virtual_keys.md#setup)
 
 Use this to avoid giving developers the raw Google AI Studio key, while still letting them use Google AI Studio endpoints.
 
-### Usage
+### Use with Virtual Keys
 
 1. Setup environment
@@ -220,4 +276,66 @@ http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:countTokens?key=sk-123
     }]
   }]
 }'
-```
\ No newline at end of file
+```
+
+
+### Send `tags` in request headers
+
+Use this if you want `tags` to be tracked in the LiteLLM DB and on logging callbacks.
+
+Pass tags in request headers as a comma-separated list. In the example below, the following tags will be tracked:
+
+```
+tags: ["gemini-js-sdk", "pass-through-endpoint"]
+```
+
+<Tabs>
+<TabItem value="curl" label="curl">
+
+```bash
+curl 'http://0.0.0.0:4000/gemini/v1beta/models/gemini-1.5-flash:generateContent?key=sk-anything' \
+-H 'Content-Type: application/json' \
+-H 'tags: gemini-js-sdk,pass-through-endpoint' \
+-d '{
+    "contents": [{
+        "parts":[{
+            "text": "The quick brown fox jumps over the lazy dog."
+        }]
+    }]
+}'
+```
+
+</TabItem>
+<TabItem value="js" label="Google AI Node.js SDK">
+
+```javascript
+const { GoogleGenerativeAI } = require("@google/generative-ai");
+
+const modelParams = {
+  model: 'gemini-pro',
+};
+
+const requestOptions = {
+  baseUrl: 'http://localhost:4000/gemini', // http://<your-proxy-base-url>/gemini
+  customHeaders: {
+    "tags": "gemini-js-sdk,pass-through-endpoint"
+  }
+};
+
+const genAI = new GoogleGenerativeAI("sk-1234");
+const model = genAI.getGenerativeModel(modelParams, requestOptions);
+
+async function main() {
+  try {
+    const result = await model.generateContent("Explain how AI works");
+    console.log(result.response.text());
+  } catch (error) {
+    console.error('Error:', error);
+  }
+}
+
+main();
+```
+
+</TabItem>
+</Tabs>
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 74e82b0ea..72d6c84c9 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -2111,6 +2111,7 @@ class SpecialHeaders(enum.Enum):
     openai_authorization = "Authorization"
     azure_authorization = "API-Key"
     anthropic_authorization = "x-api-key"
+    google_ai_studio_authorization = "x-goog-api-key"
 
 
 class LitellmDataForBackendLLMCall(TypedDict, total=False):
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index 669661e94..d19215245 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -95,6 +95,11 @@ anthropic_api_key_header = APIKeyHeader(
     auto_error=False,
     description="If anthropic client used.",
 )
+google_ai_studio_api_key_header = APIKeyHeader(
+    name=SpecialHeaders.google_ai_studio_authorization.value,
+    auto_error=False,
+    description="If google ai studio client used.",
+)
 
 
 def _get_bearer_token(
@@ -197,6 +202,9 @@ async def user_api_key_auth(  # noqa: PLR0915
     anthropic_api_key_header: Optional[str] = fastapi.Security(
         anthropic_api_key_header
     ),
+    google_ai_studio_api_key_header: Optional[str] = fastapi.Security(
+        google_ai_studio_api_key_header
+    ),
 ) -> UserAPIKeyAuth:
     from litellm.proxy.proxy_server import (
         general_settings,
@@ -233,6 +241,8 @@ async def user_api_key_auth(  # noqa: PLR0915
         api_key = azure_api_key_header
     elif isinstance(anthropic_api_key_header, str):
         api_key = anthropic_api_key_header
+    elif isinstance(google_ai_studio_api_key_header, str):
+        api_key = google_ai_studio_api_key_header
     elif pass_through_endpoints is not None:
         for endpoint in pass_through_endpoints:
             if endpoint.get("path", "") == route:
diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
index 3f4643afc..274ffda5b 100644
--- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
@@ -61,10 +61,12 @@ async def gemini_proxy_route(
     fastapi_response: Response,
 ):
     ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?key=LITELLM_API_KEY
-    api_key = request.query_params.get("key")
+    google_ai_studio_api_key = request.query_params.get("key") or request.headers.get(
+        "x-goog-api-key"
+    )
 
     user_api_key_dict = await user_api_key_auth(
-        request=request, api_key="Bearer {}".format(api_key)
+        request=request, api_key=f"Bearer {google_ai_studio_api_key}"
     )
 
     base_target_url = "https://generativelanguage.googleapis.com"
diff --git a/tests/pass_through_tests/test_gemini_with_spend.test.js b/tests/pass_through_tests/test_gemini_with_spend.test.js
new file mode 100644
index 000000000..d02237fe3
--- /dev/null
+++ b/tests/pass_through_tests/test_gemini_with_spend.test.js
@@ -0,0 +1,123 @@
+const { GoogleGenerativeAI } = 
require("@google/generative-ai"); +const fs = require('fs'); +const path = require('path'); + +// Import fetch if the SDK uses it +const originalFetch = global.fetch || require('node-fetch'); + +let lastCallId; + +// Monkey-patch the fetch used internally +global.fetch = async function patchedFetch(url, options) { + const response = await originalFetch(url, options); + + // Store the call ID if it exists + lastCallId = response.headers.get('x-litellm-call-id'); + + return response; +}; + +describe('Gemini AI Tests', () => { + test('should successfully generate non-streaming content with tags', async () => { + const genAI = new GoogleGenerativeAI("sk-1234"); // litellm proxy API key + + const requestOptions = { + baseUrl: 'http://127.0.0.1:4000/gemini', + customHeaders: { + "tags": "gemini-js-sdk,pass-through-endpoint" + } + }; + + const model = genAI.getGenerativeModel({ + model: 'gemini-pro' + }, requestOptions); + + const prompt = 'Say "hello test" and nothing else'; + + const result = await model.generateContent(prompt); + expect(result).toBeDefined(); + + // Use the captured callId + const callId = lastCallId; + console.log("Captured Call ID:", callId); + + // Wait for spend to be logged + await new Promise(resolve => setTimeout(resolve, 15000)); + + // Check spend logs + const spendResponse = await fetch( + `http://127.0.0.1:4000/spend/logs?request_id=${callId}`, + { + headers: { + 'Authorization': 'Bearer sk-1234' + } + } + ); + + const spendData = await spendResponse.json(); + console.log("spendData", spendData) + expect(spendData).toBeDefined(); + expect(spendData[0].request_id).toBe(callId); + expect(spendData[0].call_type).toBe('pass_through_endpoint'); + expect(spendData[0].request_tags).toEqual(['gemini-js-sdk', 'pass-through-endpoint']); + expect(spendData[0].metadata).toHaveProperty('user_api_key'); + expect(spendData[0].model).toContain('gemini'); + expect(spendData[0].spend).toBeGreaterThan(0); + }, 25000); + + test('should successfully generate streaming content with tags', async () => { + const genAI = new GoogleGenerativeAI("sk-1234"); // litellm proxy API key + + const requestOptions = { + baseUrl: 'http://127.0.0.1:4000/gemini', + customHeaders: { + "tags": "gemini-js-sdk,pass-through-endpoint" + } + }; + + const model = genAI.getGenerativeModel({ + model: 'gemini-pro' + }, requestOptions); + + const prompt = 'Say "hello test" and nothing else'; + + const streamingResult = await model.generateContentStream(prompt); + expect(streamingResult).toBeDefined(); + + for await (const chunk of streamingResult.stream) { + console.log('stream chunk:', JSON.stringify(chunk)); + expect(chunk).toBeDefined(); + } + + const aggregatedResponse = await streamingResult.response; + console.log('aggregated response:', JSON.stringify(aggregatedResponse)); + expect(aggregatedResponse).toBeDefined(); + + // Use the captured callId + const callId = lastCallId; + console.log("Captured Call ID:", callId); + + // Wait for spend to be logged + await new Promise(resolve => setTimeout(resolve, 15000)); + + // Check spend logs + const spendResponse = await fetch( + `http://127.0.0.1:4000/spend/logs?request_id=${callId}`, + { + headers: { + 'Authorization': 'Bearer sk-1234' + } + } + ); + + const spendData = await spendResponse.json(); + console.log("spendData", spendData) + expect(spendData).toBeDefined(); + expect(spendData[0].request_id).toBe(callId); + expect(spendData[0].call_type).toBe('pass_through_endpoint'); + expect(spendData[0].request_tags).toEqual(['gemini-js-sdk', 
'pass-through-endpoint']);
+        expect(spendData[0].metadata).toHaveProperty('user_api_key');
+        expect(spendData[0].model).toContain('gemini');
+        expect(spendData[0].spend).toBeGreaterThan(0);
+    }, 25000);
+});
diff --git a/tests/pass_through_tests/test_local_gemini.js b/tests/pass_through_tests/test_local_gemini.js
new file mode 100644
index 000000000..7043a5ab4
--- /dev/null
+++ b/tests/pass_through_tests/test_local_gemini.js
@@ -0,0 +1,55 @@
+const { GoogleGenerativeAI } = require("@google/generative-ai");
+
+const modelParams = {
+    model: 'gemini-pro',
+};
+
+const requestOptions = {
+    baseUrl: 'http://127.0.0.1:4000/gemini',
+    customHeaders: {
+        "tags": "gemini-js-sdk,gemini-pro"
+    }
+};
+
+const genAI = new GoogleGenerativeAI("sk-1234"); // litellm proxy API key
+const model = genAI.getGenerativeModel(modelParams, requestOptions);
+
+const testPrompt = "Explain how AI works";
+
+async function main() {
+    console.log("making request");
+    try {
+        const result = await model.generateContent(testPrompt);
+        console.log(result.response.text());
+    } catch (error) {
+        console.error('Error details:', {
+            name: error.name,
+            message: error.message,
+            cause: error.cause,
+            // Check if there's a network error
+            isNetworkError: error instanceof TypeError && error.message === 'fetch failed'
+        });
+
+        // Check if the server is running
+        if (error instanceof TypeError && error.message === 'fetch failed') {
+            console.error('Make sure your local server is running at http://127.0.0.1:4000');
+        }
+    }
+}
+
+
+async function main_streaming() {
+    try {
+        const streamingResult = await model.generateContentStream(testPrompt);
+        for await (const item of streamingResult.stream) {
+            console.log('stream chunk: ', JSON.stringify(item));
+        }
+        const aggregatedResponse = await streamingResult.response;
+        console.log('aggregated response: ', JSON.stringify(aggregatedResponse));
+    } catch (error) {
+        console.error('Error details:', error);
+    }
+}
+
+// main();
+main_streaming();
\ No newline at end of file
diff --git a/tests/pass_through_tests/test_vertex_with_spend.test.js b/tests/pass_through_tests/test_vertex_with_spend.test.js
index 8a5b91557..d49b1eda2 100644
--- a/tests/pass_through_tests/test_vertex_with_spend.test.js
+++ b/tests/pass_through_tests/test_vertex_with_spend.test.js
@@ -60,9 +60,9 @@ function loadVertexAiCredentials() {
 }
 
 // Run credential loading before tests
-// beforeAll(() => {
-//   loadVertexAiCredentials();
-// });
+beforeAll(() => {
+  loadVertexAiCredentials();
+});