From f5ed4992dbd2e25ef5ac712a1f2189db123fa68c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 26 Dec 2023 13:16:11 +0530 Subject: [PATCH] fix(router.py): accept dynamic api key --- litellm/router.py | 78 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 71 insertions(+), 7 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 6e1b77748..966e81c47 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -251,7 +251,20 @@ class Router: kwargs[k] = v elif k == "metadata": kwargs[k].update(v) - model_client = self._get_client(deployment=deployment, kwargs=kwargs) + potential_model_client = self._get_client( + deployment=deployment, kwargs=kwargs + ) + # check if provided keys == client keys # + dynamic_api_key = kwargs.get("api_key", None) + if ( + dynamic_api_key is not None + and potential_model_client is not None + and dynamic_api_key != potential_model_client.api_key + ): + model_client = None + else: + model_client = potential_model_client + return litellm.completion( **{ **data, @@ -303,9 +316,19 @@ class Router: elif k == "metadata": kwargs[k].update(v) - model_client = self._get_client( + potential_model_client = self._get_client( deployment=deployment, kwargs=kwargs, client_type="async" ) + # check if provided keys == client keys # + dynamic_api_key = kwargs.get("api_key", None) + if ( + dynamic_api_key is not None + and potential_model_client is not None + and dynamic_api_key != potential_model_client.api_key + ): + model_client = None + else: + model_client = potential_model_client self.total_calls[model_name] += 1 response = await litellm.acompletion( **{ @@ -361,9 +384,20 @@ class Router: elif k == "metadata": kwargs[k].update(v) - model_client = self._get_client( + potential_model_client = self._get_client( deployment=deployment, kwargs=kwargs, client_type="async" ) + # check if provided keys == client keys # + dynamic_api_key = kwargs.get("api_key", None) + if ( + dynamic_api_key is not None + and potential_model_client is not None + and dynamic_api_key != potential_model_client.api_key + ): + model_client = None + else: + model_client = potential_model_client + self.total_calls[model_name] += 1 response = litellm.image_generation( **{ @@ -419,9 +453,20 @@ class Router: elif k == "metadata": kwargs[k].update(v) - model_client = self._get_client( + potential_model_client = self._get_client( deployment=deployment, kwargs=kwargs, client_type="async" ) + # check if provided keys == client keys # + dynamic_api_key = kwargs.get("api_key", None) + if ( + dynamic_api_key is not None + and potential_model_client is not None + and dynamic_api_key != potential_model_client.api_key + ): + model_client = None + else: + model_client = potential_model_client + self.total_calls[model_name] += 1 response = await litellm.aimage_generation( **{ @@ -554,8 +599,17 @@ class Router: kwargs[k] = v elif k == "metadata": kwargs[k].update(v) - model_client = self._get_client(deployment=deployment, kwargs=kwargs) - # call via litellm.embedding() + potential_model_client = self._get_client(deployment=deployment, kwargs=kwargs) + # check if provided keys == client keys # + dynamic_api_key = kwargs.get("api_key", None) + if ( + dynamic_api_key is not None + and potential_model_client is not None + and dynamic_api_key != potential_model_client.api_key + ): + model_client = None + else: + model_client = potential_model_client return litellm.embedding( **{ **data, @@ -592,9 +646,19 @@ class Router: elif k == "metadata": kwargs[k].update(v) - model_client = self._get_client( + potential_model_client = self._get_client( deployment=deployment, kwargs=kwargs, client_type="async" ) + # check if provided keys == client keys # + dynamic_api_key = kwargs.get("api_key", None) + if ( + dynamic_api_key is not None + and potential_model_client is not None + and dynamic_api_key != potential_model_client.api_key + ): + model_client = None + else: + model_client = potential_model_client return await litellm.aembedding( **{