diff --git a/docs/my-website/docs/adding_provider/directory_structure.md b/docs/my-website/docs/adding_provider/directory_structure.md
new file mode 100644
index 0000000000..133aa6be03
--- /dev/null
+++ b/docs/my-website/docs/adding_provider/directory_structure.md
@@ -0,0 +1,21 @@
+# Directory Structure
+
+When adding a new provider, create a directory for it under `litellm/llms/` with the following structure:
+
+```
+litellm/llms/
+└── provider_name/
+    ├── completion/
+    │   ├── handler.py
+    │   └── transformation.py
+    ├── chat/
+    │   ├── handler.py
+    │   └── transformation.py
+    ├── embed/
+    │   ├── handler.py
+    │   └── transformation.py
+    └── rerank/
+        ├── handler.py
+        └── transformation.py
+```
+
diff --git a/docs/my-website/docs/adding_provider/new_rerank_provider.md b/docs/my-website/docs/adding_provider/new_rerank_provider.md
new file mode 100644
index 0000000000..4fb78f69bd
--- /dev/null
+++ b/docs/my-website/docs/adding_provider/new_rerank_provider.md
@@ -0,0 +1,81 @@
+# Add Rerank Provider
+
+LiteLLM **follows the Cohere Rerank API format** for all rerank providers. Here's how to add a new rerank provider:
+
+## 1. Create a transformation.py file
+
+Create a config class for your provider that inherits from `BaseRerankConfig`:
+
+```python
+import httpx
+
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
+from litellm.types.rerank import OptionalRerankParams, RerankRequest, RerankResponse
+
+
+class YourProviderRerankConfig(BaseRerankConfig):
+    def get_supported_cohere_rerank_params(self, model: str) -> list:
+        return [
+            "query",
+            "documents",
+            "top_n",
+            # ... other supported params
+        ]
+
+    def transform_rerank_request(self, model: str, optional_rerank_params: OptionalRerankParams, headers: dict) -> dict:
+        # Transform the request to the RerankRequest spec
+        rerank_request = RerankRequest(model=model, **optional_rerank_params)
+        return rerank_request.model_dump(exclude_none=True)
+
+    def transform_rerank_response(self, model: str, raw_response: httpx.Response, ...) -> RerankResponse:
+        # Transform the provider response to a RerankResponse
+        raw_response_json = raw_response.json()
+        return RerankResponse(**raw_response_json)
+```
+
+## 2. Register Your Provider
+
+Add your provider to `ProviderConfigManager.get_provider_rerank_config()` in `litellm/utils.py`:
+
+```python
+elif litellm.LlmProviders.YOUR_PROVIDER == provider:
+    return litellm.YourProviderRerankConfig()
+```
+
+## 3. Add Provider to `rerank_api/main.py`
+
+Add a code block to handle when your provider is called. Your provider should use the `base_llm_http_handler.rerank` method:
+
+```python
+elif _custom_llm_provider == "your_provider":
+    ...
+    response = base_llm_http_handler.rerank(
+        model=model,
+        custom_llm_provider=_custom_llm_provider,
+        optional_rerank_params=optional_rerank_params,
+        logging_obj=litellm_logging_obj,
+        timeout=optional_params.timeout,
+        api_key=dynamic_api_key or optional_params.api_key,
+        api_base=api_base,
+        _is_async=_is_async,
+        headers=headers or litellm.headers or {},
+        client=client,
+        model_response=model_response,
+    )
+    ...
+```
+
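+Once these pieces are wired up, your provider is reachable through the standard rerank entrypoint. A minimal usage sketch (`your_provider` and the model name are placeholders):
+
+```python
+import litellm
+
+response = litellm.rerank(
+    model="your_provider/your-rerank-model",  # hypothetical provider/model id
+    query="hello",
+    documents=["hello", "world"],
+    top_n=3,
+)
+print(response.results)
+```
+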
+## 4. Add Tests
+
+Add a test file to [`tests/llm_translation`](https://github.com/BerriAI/litellm/tree/main/tests/llm_translation):
+
+```python
+def test_basic_rerank_cohere():
+    response = litellm.rerank(
+        model="cohere/rerank-english-v3.0",
+        query="hello",
+        documents=["hello", "world"],
+        top_n=3,
+    )
+
+    print("re rank response: ", response)
+
+    assert response.id is not None
+    assert response.results is not None
+```
+
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index ad3914e541..931b0ecebe 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -320,6 +320,13 @@ const sidebars = {
         "load_test_rpm",
       ]
     },
+    {
+      type: "category",
+      label: "Adding Providers",
+      items: [
+        "adding_provider/directory_structure",
+        "adding_provider/new_rerank_provider",
+      ],
+    },
     {
       type: "category",
       label: "Logging & Observability",
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 0bd192d84f..1684aed8b2 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -126,6 +126,7 @@ azure_key: Optional[str] = None
 anthropic_key: Optional[str] = None
 replicate_key: Optional[str] = None
 cohere_key: Optional[str] = None
+infinity_key: Optional[str] = None
 clarifai_key: Optional[str] = None
 maritalk_key: Optional[str] = None
 ai21_key: Optional[str] = None
@@ -1025,6 +1026,7 @@ from .llms.replicate.chat.transformation import ReplicateConfig
 from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
 from .llms.cohere.rerank.transformation import CohereRerankConfig
 from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
+from .llms.infinity.rerank.transformation import InfinityRerankConfig
 from .llms.clarifai.chat.transformation import ClarifaiConfig
 from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
 from .llms.together_ai.chat import TogetherAIConfig
diff --git a/litellm/llms/infinity/rerank/common_utils.py b/litellm/llms/infinity/rerank/common_utils.py
new file mode 100644
index 0000000000..99477d1a33
--- /dev/null
+++ b/litellm/llms/infinity/rerank/common_utils.py
@@ -0,0 +1,19 @@
+import httpx
+
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+
+
+class InfinityError(BaseLLMException):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url="https://github.com/michaelfeil/infinity"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            request=self.request,
+            response=self.response,
+        )  # Call the base class constructor with the parameters it needs
diff --git a/litellm/llms/infinity/rerank/handler.py b/litellm/llms/infinity/rerank/handler.py
new file mode 100644
index 0000000000..5b8a2c0c87
--- /dev/null
+++ b/litellm/llms/infinity/rerank/handler.py
@@ -0,0 +1,5 @@
+"""
+Infinity Rerank - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""
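Taken together, the new `infinity_key` setting and the env lookups in this patch mean Infinity can be configured either programmatically or via environment variables. A minimal sketch (the base URL is illustrative; `INFINITY_API_BASE` has no default and must be set):

```python
import os

import litellm

# Option 1: environment variables (the names this patch reads)
os.environ["INFINITY_API_KEY"] = "my-api-key"
os.environ["INFINITY_API_BASE"] = "http://localhost:7997"  # illustrative base URL

# Option 2: set the module-level key directly
litellm.infinity_key = "my-api-key"

response = litellm.rerank(
    model="infinity/rerank-model",  # illustrative model name
    query="hello",
    documents=["hello", "world"],
    top_n=3,
)
```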
diff --git a/litellm/llms/infinity/rerank/transformation.py b/litellm/llms/infinity/rerank/transformation.py
new file mode 100644
index 0000000000..2d34e5299a
--- /dev/null
+++ b/litellm/llms/infinity/rerank/transformation.py
@@ -0,0 +1,91 @@
+"""
+Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format.
+
+Why a separate file? It makes it easy to see how the transformation works.
+"""
+
+import uuid
+from typing import List, Optional
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.rerank import RerankBilledUnits, RerankResponseMeta, RerankTokens
+from litellm.types.utils import RerankResponse
+
+from .common_utils import InfinityError
+
+
+class InfinityRerankConfig(CohereRerankConfig):
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            api_key = get_secret_str("INFINITY_API_KEY") or litellm.infinity_key
+
+        default_headers = {
+            "Authorization": f"bearer {api_key}",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+
+        # If 'Authorization' is provided in headers, it overrides the default.
+        if "Authorization" in headers:
+            default_headers["Authorization"] = headers["Authorization"]
+
+        # Merge remaining headers; caller-provided headers override the defaults
+        return {**default_headers, **headers}
+
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> RerankResponse:
+        """
+        Transform Infinity rerank response
+
+        No transformation required, Infinity follows the Cohere API response format
+        """
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise InfinityError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+
+        _billed_units = RerankBilledUnits(**raw_response_json.get("usage", {}))
+        _tokens = RerankTokens(
+            input_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
+            output_tokens=(
+                raw_response_json.get("usage", {}).get("total_tokens", 0)
+                - raw_response_json.get("usage", {}).get("prompt_tokens", 0)
+            ),
+        )
+        rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
+
+        _results: Optional[List[dict]] = raw_response_json.get("results")
+
+        if _results is None:
+            raise ValueError(f"No results found in the response={raw_response_json}")
+
+        return RerankResponse(
+            id=raw_response_json.get("id") or str(uuid.uuid4()),
+            results=_results,  # type: ignore
+            meta=rerank_meta,
+        )
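A worked example of the usage math above: Infinity reports `prompt_tokens` and `total_tokens`, and output tokens are derived by subtraction. With the payload used in the tests below:

```python
usage = {"prompt_tokens": 100, "total_tokens": 150}

input_tokens = usage.get("prompt_tokens", 0)                 # 100
output_tokens = usage.get("total_tokens", 0) - input_tokens  # 150 - 100 = 50

assert (input_tokens, output_tokens) == (100, 50)
```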
diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py
index 72de2ca8ed..41ca98998c 100644
--- a/litellm/rerank_api/main.py
+++ b/litellm/rerank_api/main.py
@@ -76,7 +76,9 @@ def rerank(  # noqa: PLR0915
     model: str,
     query: str,
     documents: List[Union[str, Dict[str, Any]]],
-    custom_llm_provider: Optional[Literal["cohere", "together_ai", "azure_ai"]] = None,
+    custom_llm_provider: Optional[
+        Literal["cohere", "together_ai", "azure_ai", "infinity"]
+    ] = None,
     top_n: Optional[int] = None,
     rank_fields: Optional[List[str]] = None,
     return_documents: Optional[bool] = True,
@@ -188,6 +190,37 @@ def rerank(  # noqa: PLR0915
             or litellm.api_base
             or get_secret("AZURE_AI_API_BASE")  # type: ignore
         )
+        response = base_llm_http_handler.rerank(
+            model=model,
+            custom_llm_provider=_custom_llm_provider,
+            optional_rerank_params=optional_rerank_params,
+            logging_obj=litellm_logging_obj,
+            timeout=optional_params.timeout,
+            api_key=dynamic_api_key or optional_params.api_key,
+            api_base=api_base,
+            _is_async=_is_async,
+            headers=headers or litellm.headers or {},
+            client=client,
+            model_response=model_response,
+        )
+    elif _custom_llm_provider == "infinity":
+        # Implement Infinity rerank logic
+        api_key: Optional[str] = (
+            dynamic_api_key or optional_params.api_key or litellm.api_key
+        )
+
+        api_base: Optional[str] = (
+            dynamic_api_base
+            or optional_params.api_base
+            or litellm.api_base
+            or get_secret("INFINITY_API_BASE")  # type: ignore
+        )
+
+        if api_base is None:
+            raise Exception(
+                "Invalid api base. api_base=None. Set in call or via `INFINITY_API_BASE` env var."
+            )
+
         response = base_llm_http_handler.rerank(
             model=model,
             custom_llm_provider=_custom_llm_provider,
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 9e299e62f6..487bc64bfe 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1741,6 +1741,7 @@ class LlmProviders(str, Enum):
     HOSTED_VLLM = "hosted_vllm"
     LM_STUDIO = "lm_studio"
     GALADRIEL = "galadriel"
+    INFINITY = "infinity"
 
 
 class LiteLLMLoggingBaseClass:
diff --git a/litellm/utils.py b/litellm/utils.py
index a16075ebc2..7fbd586d98 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6214,6 +6214,8 @@ class ProviderConfigManager:
             return litellm.CohereRerankConfig()
         elif litellm.LlmProviders.AZURE_AI == provider:
             return litellm.AzureAIRerankConfig()
+        elif litellm.LlmProviders.INFINITY == provider:
+            return litellm.InfinityRerankConfig()
         return litellm.CohereRerankConfig()
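To sanity-check the dispatch above, something like the following should hold. A sketch; the full signature of `get_provider_rerank_config` isn't shown in this hunk, so the call shape here is an assumption:

```python
import litellm
from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager

# Assumed call shape; only the dispatch body is visible in the hunk above.
config = ProviderConfigManager.get_provider_rerank_config(
    model="rerank-model", provider=LlmProviders.INFINITY
)
assert isinstance(config, litellm.InfinityRerankConfig)
```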
diff --git a/tests/llm_translation/test_infinity.py b/tests/llm_translation/test_infinity.py
new file mode 100644
index 0000000000..bab64a4da3
--- /dev/null
+++ b/tests/llm_translation/test_infinity.py
@@ -0,0 +1,151 @@
+import json
+import os
+import sys
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+import litellm
+from test_rerank import assert_response_shape
+
+
+@pytest.mark.asyncio()
+async def test_infinity_rerank():
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "id": "cmpl-mockid",
+            "results": [{"index": 0, "relevance_score": 0.95}],
+            "usage": {"prompt_tokens": 100, "total_tokens": 150},
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "model": "rerank-model",
+        "query": "hello",
+        "top_n": 3,
+        "documents": ["hello", "world"],
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        response = await litellm.arerank(
+            model="infinity/rerank-model",
+            query="hello",
+            documents=["hello", "world"],
+            top_n=3,
+            api_base="https://api.infinity.ai",
+        )
+
+        print("async re rank response: ", response)
+
+        # Assert
+        mock_post.assert_called_once()
+        print("call args", mock_post.call_args)
+        args_to_api = mock_post.call_args.kwargs["data"]
+        _url = mock_post.call_args.kwargs["url"]
+        print("Arguments passed to API=", args_to_api)
+        print("url = ", _url)
+        assert _url == "https://api.infinity.ai/v1/rerank"
+
+        request_data = json.loads(args_to_api)
+        assert request_data["query"] == expected_payload["query"]
+        assert request_data["documents"] == expected_payload["documents"]
+        assert request_data["top_n"] == expected_payload["top_n"]
+        assert request_data["model"] == expected_payload["model"]
+
+        assert response.id is not None
+        assert response.results is not None
+        assert response.meta["tokens"]["input_tokens"] == 100
+        assert (
+            response.meta["tokens"]["output_tokens"] == 50
+        )  # total_tokens - prompt_tokens
+
+        assert_response_shape(response, custom_llm_provider="infinity")
+
+
+@pytest.mark.asyncio()
+async def test_infinity_rerank_with_env(monkeypatch):
+    # Set up mock response
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "id": "cmpl-mockid",
+            "results": [{"index": 0, "relevance_score": 0.95}],
+            "usage": {"prompt_tokens": 100, "total_tokens": 150},
+        }
+
+    mock_response.json = return_val
+    mock_response.headers = {"key": "value"}
+    mock_response.status_code = 200
+
+    # Set environment variable
+    monkeypatch.setenv("INFINITY_API_BASE", "https://env.infinity.ai")
+
+    expected_payload = {
+        "model": "rerank-model",
+        "query": "hello",
+        "top_n": 3,
+        "documents": ["hello", "world"],
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        response = await litellm.arerank(
+            model="infinity/rerank-model",
+            query="hello",
+            documents=["hello", "world"],
+            top_n=3,
+        )
+
+        print("async re rank response: ", response)
+
+        # Assert
+        mock_post.assert_called_once()
+        print("call args", mock_post.call_args)
+        args_to_api = mock_post.call_args.kwargs["data"]
+        _url = mock_post.call_args.kwargs["url"]
+        print("Arguments passed to API=", args_to_api)
+        print("url = ", _url)
+        assert _url == "https://env.infinity.ai/v1/rerank"
+
+        request_data = json.loads(args_to_api)
+        assert request_data["query"] == expected_payload["query"]
+        assert request_data["documents"] == expected_payload["documents"]
+        assert request_data["top_n"] == expected_payload["top_n"]
+        assert request_data["model"] == expected_payload["model"]
+
+        assert response.id is not None
+        assert response.results is not None
+        assert response.meta["tokens"]["input_tokens"] == 100
+        assert (
+            response.meta["tokens"]["output_tokens"] == 50
+        )  # total_tokens - prompt_tokens
+
+        assert_response_shape(response, custom_llm_provider="infinity")
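Since Infinity mirrors the Cohere rerank format, the wire-level exchange these tests assert on can be reproduced directly against a server. A sketch (the base URL is illustrative; the response shape matches the mocks above):

```python
import httpx

# Same request shape the tests assert on; base URL is illustrative.
response = httpx.post(
    "https://api.infinity.ai/v1/rerank",
    headers={
        "Authorization": "bearer my-api-key",
        "accept": "application/json",
        "content-type": "application/json",
    },
    json={
        "model": "rerank-model",
        "query": "hello",
        "documents": ["hello", "world"],
        "top_n": 3,
    },
)
print(response.json())
# e.g. {"id": "...", "results": [{"index": 0, "relevance_score": 0.95}],
#       "usage": {"prompt_tokens": 100, "total_tokens": 150}}
```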