diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index ef11d798e..536f6e2e5 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -4,10 +4,15 @@ model_list:
       model: openai/fake
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  - model_name: Salesforce/Llama-Rank-V1
+    litellm_params:
+      model: together_ai/Salesforce/Llama-Rank-V1
+      api_key: os.environ/TOGETHERAI_API_KEY
+  - model_name: rerank-english-v3.0
+    litellm_params:
+      model: cohere/rerank-english-v3.0
+      api_key: os.environ/COHERE_API_KEY
 
 # default off mode
 litellm_settings:
-  set_verbose: True
-  cache: True
-  cache_params:
-    mode: default_off
+  set_verbose: True
\ No newline at end of file
diff --git a/litellm/router.py b/litellm/router.py
index 2d180f076..bc23972b6 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1641,6 +1641,104 @@ class Router:
                 self.fail_calls[model_name] += 1
             raise e
 
+    async def arerank(self, model: str, **kwargs):
+        try:
+            kwargs["model"] = model
+            # rerank-specific params (query, documents, top_n, ...) are passed through via **kwargs
+            kwargs["original_function"] = self._arerank
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+
+            response = await self.async_function_with_fallbacks(**kwargs)
+
+            return response
+        except Exception as e:
+            asyncio.create_task(
+                send_llm_exception_alert(
+                    litellm_router_instance=self,
+                    request_kwargs=kwargs,
+                    error_traceback_str=traceback.format_exc(),
+                    original_exception=e,
+                )
+            )
+            raise e
+
+    async def _arerank(self, model: str, **kwargs):
+        model_name = None
+        try:
+            verbose_router_logger.debug(
+                f"Inside _arerank()- model: {model}; kwargs: {kwargs}"
+            )
+            deployment = await self.async_get_available_deployment(
+                model=model,
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            kwargs.setdefault("metadata", {}).update(
+                {
+                    "deployment": deployment["litellm_params"]["model"],
+                    "model_info": deployment.get("model_info", {}),
+                }
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs and v is not None
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="async"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
+            else:
+                model_client = potential_model_client
+            self.total_calls[model_name] += 1
+
+            timeout = (
+                data.get(
+                    "timeout", None
+                )  # timeout set on litellm_params for this deployment
+                or self.timeout  # timeout set on router
+                or kwargs.get(
+                    "timeout", None
+                )  # this uses default_litellm_params when nothing is set
+            )
+
+            response = await litellm.arerank(
+                **{
+                    **data,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    "timeout": timeout,
+                    **kwargs,
+                }
+            )
+
+            self.success_calls[model_name] += 1
+            verbose_router_logger.info(
+                f"litellm.arerank(model={model_name})\033[32m 200 OK\033[0m"
+            )
+            return response
+        except Exception as e:
+            verbose_router_logger.info(
+                f"litellm.arerank(model={model_name})\033[31m Exception {str(e)}\033[0m"
+            )
+            if model_name is not None:
+                self.fail_calls[model_name] += 1
+            raise e
+
     def text_completion(
         self,
         model: str,