diff --git a/docs/my-website/docs/proxy/user_keys.md b/docs/my-website/docs/proxy/user_keys.md
new file mode 100644
index 000000000..a02ebafc4
--- /dev/null
+++ b/docs/my-website/docs/proxy/user_keys.md
@@ -0,0 +1,56 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+# Pass in User Keys
+
+Send user keys to the proxy
+
+
+Here's how to do it:
+
+
+
+
+Pass in the litellm_params (E.g. api_key, api_base, etc.) via the `extra_body` parameter in the OpenAI client.
+
+```python
+import openai
+client = openai.OpenAI(
+    api_key="sk-1234",
+    base_url="http://0.0.0.0:8000"
+)
+
+# request sent to model set on litellm proxy, `litellm --model`
+response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
+    {
+        "role": "user",
+        "content": "this is a test request, write a short poem"
+    }
+],
+    extra_body={"api_key": "my-bad-key"}) # 👈 User Key
+
+print(response)
+```
+
+
+```javascript
+const { OpenAI } = require('openai');
+
+const openai = new OpenAI({
+  apiKey: "sk-1234", // This is the default and can be omitted
+  baseURL: "http://0.0.0.0:8000"
+});
+
+async function main() {
+  const chatCompletion = await openai.chat.completions.create({
+    messages: [{ role: 'user', content: 'Say this is a test' }],
+    model: 'gpt-3.5-turbo',
+    api_key: "my-bad-key" // 👈 User Key
+  });
+}
+
+main();
+```
+
+
\ No newline at end of file
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index 069faa48a..c0c259474 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -102,6 +102,7 @@ const sidebars = {
         "proxy/embedding",
         "proxy/load_balancing",
         "proxy/virtual_keys",
+        "proxy/user_keys",
         "proxy/model_management",
         "proxy/reliability",
         "proxy/health",
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index f4a68eb97..48881d24a 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1126,7 +1126,10 @@ async def completion(
             if llm_model_list is not None
             else []
         )
-        if (
+        # skip router if user passed their key
+        if "api_key" in data:
+            response = await litellm.atext_completion(**data)
+        elif (
             llm_router is not None and data["model"] in router_model_names
         ):  # model in router model list
             response = await llm_router.atext_completion(**data)
@@ -1259,7 +1262,10 @@ async def chat_completion(
             if llm_model_list is not None
             else []
         )
-        if (
+        # skip router if user passed their key
+        if "api_key" in data:
+            response = await litellm.acompletion(**data)
+        elif (
             llm_router is not None and data["model"] in router_model_names
         ):  # model in router model list
             response = await llm_router.acompletion(**data)
@@ -1414,7 +1420,10 @@ async def embeddings(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )
         ## ROUTE TO CORRECT ENDPOINT ##
-        if (
+        # skip router if user passed their key
+        if "api_key" in data:
+            response = await litellm.aembedding(**data)
+        elif (
             llm_router is not None and data["model"] in router_model_names
         ):  # model in router model list
             response = await llm_router.aembedding(**data)
@@ -1515,7 +1524,10 @@ async def image_generation(
             user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
         )
         ## ROUTE TO CORRECT ENDPOINT ##
-        if (
+        # skip router if user passed their key
+        if "api_key" in data:
+            response = await litellm.aimage_generation(**data)
+        elif (
             llm_router is not None and data["model"] in router_model_names
         ):  # model in router model list
             response = await llm_router.aimage_generation(**data)
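Usage note on the proxy changes above: each endpoint now forwards the entire request body to the matching `litellm.*` call whenever `"api_key"` is present, so the other litellm_params mentioned in the docs page (e.g. `api_base`) can be sent through the same `extra_body`. The sketch below is illustrative only, not part of the patch; it assumes a proxy running on `http://0.0.0.0:8000`, and the key and `api_base` values are made up.

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:8000")

# Both fields in extra_body are user-supplied litellm_params. Because "api_key" is in
# the request body, the proxy skips its router and calls litellm.acompletion(**data)
# directly, so api_base is forwarded along with it.
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "this is a test request, write a short poem"}],
    extra_body={
        "api_key": "my-own-provider-key",                        # 👈 user key (made up)
        "api_base": "https://my-openai-compatible-endpoint/v1",  # hypothetical base URL
    },
)
print(response)
```

Requests that take this branch bypass the router entirely, so router-level behavior (deployment selection, cooldowns) does not apply to them.
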
diff --git a/litellm/router.py b/litellm/router.py
index 966e81c47..cd0263d8e 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1060,12 +1060,12 @@ class Router:
                 "custom_llm_provider", None
             )  # i.e. azure
             metadata = kwargs.get("litellm_params", {}).get("metadata", None)
-            deployment_id = (
-                kwargs.get("litellm_params", {}).get("model_info", {}).get("id", None)
-            )
-            self._set_cooldown_deployments(
-                deployment_id
-            )  # setting deployment_id in cooldown deployments
+            _model_info = kwargs.get("litellm_params", {}).get("model_info", {})
+            if isinstance(_model_info, dict):
+                deployment_id = _model_info.get("id", None)
+                self._set_cooldown_deployments(
+                    deployment_id
+                )  # setting deployment_id in cooldown deployments
             if metadata:
                 deployment = metadata.get("deployment", None)
                 deployment_exceptions = self.model_exception_map.get(deployment, [])
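On the `litellm/router.py` hunk: `kwargs.get("litellm_params", {}).get("model_info", {})` only falls back to `{}` when the `model_info` key is absent; if the key is present with a `None` value, the old chained `.get("id", None)` raised `AttributeError` inside the failure callback. Below is a standalone sketch of the guarded lookup the patch introduces, not the `Router` class itself, and the `litellm_params` values are made up for illustration.

```python
# Hypothetical litellm_params where model_info is present but None.
litellm_params = {"metadata": {"deployment": "azure/gpt-35-turbo"}, "model_info": None}

# Old pattern: .get("model_info", {}) returns None (the key exists), so the chained
# .get("id", None) raises AttributeError.
try:
    deployment_id = litellm_params.get("model_info", {}).get("id", None)
except AttributeError as err:
    print(f"old lookup failed: {err}")

# New pattern: read the id and run the cooldown logic only when model_info is a dict.
_model_info = litellm_params.get("model_info", {})
if isinstance(_model_info, dict):
    deployment_id = _model_info.get("id", None)
    print(f"cooling down deployment: {deployment_id}")
else:
    print("model_info missing or not a dict; skipping cooldown")
```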