From 282b9a37e5f563202e61216fba065914a4995421 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Tue, 28 Nov 2023 16:34:16 -0800
Subject: [PATCH] (fix) router: passing client

---
 litellm/main.py   |  4 +++-
 litellm/router.py | 37 ++++++++++++++++++++++---------------
 2 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index e0e4757c2f..ee88afe969 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1765,6 +1765,7 @@ def embedding(
             timeout=timeout,
             model_response=EmbeddingResponse(),
             optional_params=optional_params,
+            client=optional_params.pop("client", None)
         )
     elif model in litellm.open_ai_embedding_models or custom_llm_provider == "openai":
         api_base = (
@@ -1798,7 +1799,8 @@
             logging_obj=logging,
             timeout=timeout,
             model_response=EmbeddingResponse(),
-            optional_params=optional_params
+            optional_params=optional_params,
+            client=optional_params.pop("client", None)
         )
     elif model in litellm.cohere_embedding_models:
         cohere_key = (
diff --git a/litellm/router.py b/litellm/router.py
index 2c170ee4bc..850b037a1a 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -102,7 +102,7 @@ class Router:
 
         self.default_litellm_params["timeout"] = timeout
         self.default_litellm_params["max_retries"] = 0
-        
+
         ### HEALTH CHECK THREAD ###
         if self.routing_strategy == "least-busy":
             self._start_health_check_thread()
@@ -188,7 +188,8 @@ class Router:
                 data["model"] = original_model_string[:index_of_model_id]
             else:
                 data["model"] = original_model_string
-            return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
+            model_client = deployment.get("client", None)
+            return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
         except Exception as e:
             raise e
 
@@ -232,7 +233,7 @@ class Router:
                 data["model"] = original_model_string[:index_of_model_id]
             else:
                 data["model"] = original_model_string
-            model_client = deployment["async_client"]
+            model_client = deployment.get("async_client", None)
             self.total_calls[original_model_string] +=1
             response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
             self.success_calls[original_model_string] +=1
@@ -301,8 +302,9 @@ class Router:
             data["model"] = original_model_string[:index_of_model_id]
         else:
             data["model"] = original_model_string
+        model_client = deployment.get("client", None)
         # call via litellm.embedding()
-        return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
+        return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
 
     async def aembedding(self,
                          model: str,
@@ -325,7 +327,9 @@
             data["model"] = original_model_string[:index_of_model_id]
         else:
             data["model"] = original_model_string
-        return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
+        model_client = deployment.get("async_client", None)
+
+        return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
 
     async def async_function_with_fallbacks(self, *args, **kwargs):
         """
@@ -801,23 +805,26 @@ class Router:
         self.model_list = model_list
         # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
         for model in self.model_list:
-            if "azure" in model["litellm_params"]["model"]:
+            litellm_params = model.get("litellm_params", {})
+            model_name = litellm_params.get("model")
+
+            if "azure" in model_name:
                 model["async_client"] = openai.AsyncAzureOpenAI(
-                    api_key= model["litellm_params"]["api_key"],
-                    azure_endpoint = model["litellm_params"]["api_base"]
+                    api_key=litellm_params.get("api_key"),
+                    azure_endpoint=litellm_params.get("api_base")
                 )
                 model["client"] = openai.AzureOpenAI(
-                    api_key= model["litellm_params"]["api_key"],
-                    azure_endpoint = model["litellm_params"]["api_base"]
+                    api_key=litellm_params.get("api_key"),
+                    azure_endpoint=litellm_params.get("api_base")
                 )
-            elif model["litellm_params"]["model"] in litellm.open_ai_chat_completion_models:
+            elif model_name in litellm.open_ai_chat_completion_models:
                 model["async_client"] = openai.AsyncOpenAI(
-                    api_key= model["litellm_params"]["api_key"],
-                    base_ur = model["litellm_params"]["api_base"]
+                    api_key=litellm_params.get("api_key"),
+                    base_url=litellm_params.get("api_base")
                 )
                 model["client"] = openai.OpenAI(
-                    api_key= model["litellm_params"]["api_key"],
-                    base_url = model["litellm_params"]["api_base"]
+                    api_key=litellm_params.get("api_key"),
+                    base_url=litellm_params.get("api_base")
                 )
             model_id = ""
             for key in model["litellm_params"]:
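Usage sketch (illustrative only, not part of the patch): assuming litellm's documented
model_list format, the intent of the change is that Router.set_model_list() pre-builds an
openai.OpenAI / openai.AsyncOpenAI (or the Azure variants) client per deployment, and the
completion/acompletion/embedding/aembedding paths now forward that client to litellm via
the "client" kwarg instead of letting litellm construct a fresh client on every request.
The model names, api_key, api_base, and api_version below are placeholders.

    import litellm
    from litellm import Router

    # Placeholder deployment config: "model_name" is the alias callers use,
    # "litellm_params" describes the actual deployment behind it.
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",                        # placeholder deployment
                "api_key": "sk-...",                                 # placeholder
                "api_base": "https://my-endpoint.openai.azure.com",  # placeholder
                "api_version": "2023-07-01-preview",                 # placeholder
            },
        },
    ]

    router = Router(model_list=model_list)

    # set_model_list() (patched above) attaches model["client"] / model["async_client"];
    # completion() then reads deployment.get("client", None) and passes it through, so the
    # same pre-initialized client is reused across requests for this deployment.
    response = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(response)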