From 51ec270501765fc0454220422b583275004e1d78 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 14 Nov 2024 18:11:18 +0530 Subject: [PATCH] feat(jina_ai/): add rerank support Closes https://github.com/BerriAI/litellm/issues/6691 --- .../llms/jina_ai/embedding/transformation.py | 2 +- litellm/llms/jina_ai/rerank/handler.py | 96 +++++++++++++ litellm/llms/jina_ai/rerank/transformation.py | 36 +++++ litellm/main.py | 2 +- litellm/proxy/_new_secret_config.yaml | 127 ++---------------- litellm/rerank_api/main.py | 18 +++ litellm/types/rerank.py | 19 ++- .../llm_translation/base_rerank_unit_tests.py | 115 ++++++++++++++++ tests/llm_translation/test_jina_ai.py | 23 ++++ tests/local_testing/test_get_llm_provider.py | 2 +- 10 files changed, 319 insertions(+), 121 deletions(-) create mode 100644 litellm/llms/jina_ai/rerank/handler.py create mode 100644 litellm/llms/jina_ai/rerank/transformation.py create mode 100644 tests/llm_translation/base_rerank_unit_tests.py create mode 100644 tests/llm_translation/test_jina_ai.py diff --git a/litellm/llms/jina_ai/embedding/transformation.py b/litellm/llms/jina_ai/embedding/transformation.py index 26ff58878..97b7b2cfa 100644 --- a/litellm/llms/jina_ai/embedding/transformation.py +++ b/litellm/llms/jina_ai/embedding/transformation.py @@ -76,4 +76,4 @@ class JinaAIEmbeddingConfig: or get_secret_str("JINA_AI_API_KEY") or get_secret_str("JINA_AI_TOKEN") ) - return LlmProviders.OPENAI_LIKE.value, api_base, dynamic_api_key + return LlmProviders.JINA_AI.value, api_base, dynamic_api_key diff --git a/litellm/llms/jina_ai/rerank/handler.py b/litellm/llms/jina_ai/rerank/handler.py new file mode 100644 index 000000000..a2cfdd49e --- /dev/null +++ b/litellm/llms/jina_ai/rerank/handler.py @@ -0,0 +1,96 @@ +""" +Re rank api + +LiteLLM supports the re rank API format, no paramter transformation occurs +""" + +import uuid +from typing import Any, Dict, List, Optional, Union + +import httpx +from pydantic import BaseModel + +import litellm +from litellm.llms.base import BaseLLM +from litellm.llms.custom_httpx.http_handler import ( + _get_httpx_client, + get_async_httpx_client, +) +from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig +from litellm.types.rerank import RerankRequest, RerankResponse + + +class JinaAIRerank(BaseLLM): + def rerank( + self, + model: str, + api_key: str, + query: str, + documents: List[Union[str, Dict[str, Any]]], + top_n: Optional[int] = None, + rank_fields: Optional[List[str]] = None, + return_documents: Optional[bool] = True, + max_chunks_per_doc: Optional[int] = None, + _is_async: Optional[bool] = False, + ) -> RerankResponse: + client = _get_httpx_client() + + request_data = RerankRequest( + model=model, + query=query, + top_n=top_n, + documents=documents, + rank_fields=rank_fields, + return_documents=return_documents, + ) + + # exclude None values from request_data + request_data_dict = request_data.dict(exclude_none=True) + + if _is_async: + return self.async_rerank(request_data_dict, api_key) # type: ignore # Call async method + + response = client.post( + "https://api.jina.ai/v1/rerank", + headers={ + "accept": "application/json", + "content-type": "application/json", + "authorization": f"Bearer {api_key}", + }, + json=request_data_dict, + ) + + if response.status_code != 200: + raise Exception(response.text) + + _json_response = response.json() + + return JinaAIRerankConfig()._transform_response(_json_response) + + async def async_rerank( # New async method + self, + request_data_dict: Dict[str, Any], + api_key: str, + ) -> RerankResponse: + client = get_async_httpx_client( + llm_provider=litellm.LlmProviders.JINA_AI + ) # Use async client + + response = await client.post( + "https://api.jina.ai/v1/rerank", + headers={ + "accept": "application/json", + "content-type": "application/json", + "authorization": f"Bearer {api_key}", + }, + json=request_data_dict, + ) + + if response.status_code != 200: + raise Exception(response.text) + + _json_response = response.json() + + return JinaAIRerankConfig()._transform_response(_json_response) + + pass diff --git a/litellm/llms/jina_ai/rerank/transformation.py b/litellm/llms/jina_ai/rerank/transformation.py new file mode 100644 index 000000000..82039a15b --- /dev/null +++ b/litellm/llms/jina_ai/rerank/transformation.py @@ -0,0 +1,36 @@ +""" +Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format. + +Why separate file? Make it easy to see how transformation works + +Docs - https://jina.ai/reranker +""" + +import uuid +from typing import List, Optional + +from litellm.types.rerank import ( + RerankBilledUnits, + RerankResponse, + RerankResponseMeta, + RerankTokens, +) + + +class JinaAIRerankConfig: + def _transform_response(self, response: dict) -> RerankResponse: + + _billed_units = RerankBilledUnits(**response.get("usage", {})) + _tokens = RerankTokens(**response.get("usage", {})) + rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens) + + _results: Optional[List[dict]] = response.get("results") + + if _results is None: + raise ValueError(f"No results found in the response={response}") + + return RerankResponse( + id=response.get("id") or str(uuid.uuid4()), + results=_results, + meta=rerank_meta, + ) # Return response diff --git a/litellm/main.py b/litellm/main.py index afb46c698..ad8f791c3 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3455,7 +3455,7 @@ def embedding( # noqa: PLR0915 client=client, aembedding=aembedding, ) - elif custom_llm_provider == "openai_like": + elif custom_llm_provider == "openai_like" or custom_llm_provider == "jina_ai": api_base = ( api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE") ) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 911f15b86..b06a9e667 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,122 +1,15 @@ model_list: - - model_name: "*" - litellm_params: - model: claude-3-5-sonnet-20240620 - api_key: os.environ/ANTHROPIC_API_KEY - - model_name: claude-3-5-sonnet-aihubmix - litellm_params: - model: openai/claude-3-5-sonnet-20240620 - input_cost_per_token: 0.000003 # 3$/M - output_cost_per_token: 0.000015 # 15$/M - api_base: "https://exampleopenaiendpoint-production.up.railway.app" - api_key: my-fake-key - - model_name: fake-openai-endpoint-2 - litellm_params: - model: openai/my-fake-model - api_key: my-fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - stream_timeout: 0.001 - timeout: 1 - rpm: 1 - - model_name: fake-openai-endpoint - litellm_params: - model: openai/my-fake-model - api_key: my-fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - ## bedrock chat completions - - model_name: "*anthropic.claude*" - litellm_params: - model: bedrock/*anthropic.claude* - aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID - aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY - aws_region_name: os.environ/AWS_REGION_NAME - guardrailConfig: - "guardrailIdentifier": "h4dsqwhp6j66" - "guardrailVersion": "2" - "trace": "enabled" - -## bedrock embeddings - - model_name: "*amazon.titan-embed-*" - litellm_params: - model: bedrock/amazon.titan-embed-* - aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID - aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY - aws_region_name: os.environ/AWS_REGION_NAME - - model_name: "*cohere.embed-*" - litellm_params: - model: bedrock/cohere.embed-* - aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID - aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY - aws_region_name: os.environ/AWS_REGION_NAME - - - model_name: "bedrock/*" - litellm_params: - model: bedrock/* - aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID - aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY - aws_region_name: os.environ/AWS_REGION_NAME - + # GPT-4 Turbo Models - model_name: gpt-4 litellm_params: - model: azure/chatgpt-v-2 - api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ - api_version: "2023-05-15" - api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault - rpm: 480 - timeout: 300 - stream_timeout: 60 - -litellm_settings: - fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }] - # callbacks: ["otel", "prometheus"] - default_redis_batch_cache_expiry: 10 - # default_team_settings: - # - team_id: "dbe2f686-a686-4896-864a-4c3924458709" - # success_callback: ["langfuse"] - # langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1 - # langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1 - -# litellm_settings: -# cache: True -# cache_params: -# type: redis - -# # disable caching on the actual API call -# supported_call_types: [] - -# # see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url -# host: os.environ/REDIS_HOST -# port: os.environ/REDIS_PORT -# password: os.environ/REDIS_PASSWORD - -# # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests -# # see https://docs.litellm.ai/docs/proxy/prometheus -# callbacks: ['otel'] + model: gpt-4 + - model_name: rerank-model + litellm_params: + model: jina_ai/jina-reranker-v2-base-multilingual -# # router_settings: -# # routing_strategy: latency-based-routing -# # routing_strategy_args: -# # # only assign 40% of traffic to the fastest deployment to avoid overloading it -# # lowest_latency_buffer: 0.4 - -# # # consider last five minutes of calls for latency calculation -# # ttl: 300 -# # redis_host: os.environ/REDIS_HOST -# # redis_port: os.environ/REDIS_PORT -# # redis_password: os.environ/REDIS_PASSWORD - -# # # see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml -# # general_settings: -# # master_key: os.environ/LITELLM_MASTER_KEY -# # database_url: os.environ/DATABASE_URL -# # disable_master_key_return: true -# # # alerting: ['slack', 'email'] -# # alerting: ['email'] - -# # # Batch write spend updates every 60s -# # proxy_batch_write_at: 60 - -# # # see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl -# # # our api keys rarely change -# # user_api_key_cache_ttl: 3600 +router_settings: + model_group_alias: + "gpt-4-turbo": # Aliased model name + model: "gpt-4" # Actual model name in 'model_list' + hidden: true \ No newline at end of file diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index a06aff135..70353acad 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -8,6 +8,7 @@ from litellm._logging import verbose_logger from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.azure_ai.rerank import AzureAIRerank from litellm.llms.cohere.rerank import CohereRerank +from litellm.llms.jina_ai.rerank.handler import JinaAIRerank from litellm.llms.together_ai.rerank import TogetherAIRerank from litellm.secret_managers.main import get_secret from litellm.types.rerank import RerankRequest, RerankResponse @@ -19,6 +20,7 @@ from litellm.utils import client, exception_type, supports_httpx_timeout cohere_rerank = CohereRerank() together_rerank = TogetherAIRerank() azure_ai_rerank = AzureAIRerank() +jina_ai_rerank = JinaAIRerank() ################################################# @@ -247,7 +249,23 @@ def rerank( api_key=api_key, _is_async=_is_async, ) + elif _custom_llm_provider == "jina_ai": + if dynamic_api_key is None: + raise ValueError( + "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment" + ) + response = jina_ai_rerank.rerank( + model=model, + api_key=dynamic_api_key, + query=query, + documents=documents, + top_n=top_n, + rank_fields=rank_fields, + return_documents=return_documents, + max_chunks_per_doc=max_chunks_per_doc, + _is_async=_is_async, + ) else: raise ValueError(f"Unsupported provider: {_custom_llm_provider}") diff --git a/litellm/types/rerank.py b/litellm/types/rerank.py index d016021fb..00b07ba13 100644 --- a/litellm/types/rerank.py +++ b/litellm/types/rerank.py @@ -7,6 +7,7 @@ https://docs.cohere.com/reference/rerank from typing import List, Optional, Union from pydantic import BaseModel, PrivateAttr +from typing_extensions import TypedDict class RerankRequest(BaseModel): @@ -19,10 +20,26 @@ class RerankRequest(BaseModel): max_chunks_per_doc: Optional[int] = None +class RerankBilledUnits(TypedDict, total=False): + search_units: int + total_tokens: int + + +class RerankTokens(TypedDict, total=False): + input_tokens: int + output_tokens: int + + +class RerankResponseMeta(TypedDict, total=False): + api_version: dict + billed_units: RerankBilledUnits + tokens: RerankTokens + + class RerankResponse(BaseModel): id: str results: List[dict] # Contains index and relevance_score - meta: Optional[dict] = None # Contains api_version and billed_units + meta: Optional[RerankResponseMeta] = None # Contains api_version and billed_units # Define private attributes using PrivateAttr _hidden_params: dict = PrivateAttr(default_factory=dict) diff --git a/tests/llm_translation/base_rerank_unit_tests.py b/tests/llm_translation/base_rerank_unit_tests.py new file mode 100644 index 000000000..2a8b80194 --- /dev/null +++ b/tests/llm_translation/base_rerank_unit_tests.py @@ -0,0 +1,115 @@ +import asyncio +import httpx +import json +import pytest +import sys +from typing import Any, Dict, List +from unittest.mock import MagicMock, Mock, patch +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm +from litellm.exceptions import BadRequestError +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.utils import ( + CustomStreamWrapper, + get_supported_openai_params, + get_optional_params, +) + +# test_example.py +from abc import ABC, abstractmethod + + +def assert_response_shape(response, custom_llm_provider): + expected_response_shape = {"id": str, "results": list, "meta": dict} + + expected_results_shape = {"index": int, "relevance_score": float} + + expected_meta_shape = {"api_version": dict, "billed_units": dict} + + expected_api_version_shape = {"version": str} + + expected_billed_units_shape = {"search_units": int} + + assert isinstance(response.id, expected_response_shape["id"]) + assert isinstance(response.results, expected_response_shape["results"]) + for result in response.results: + assert isinstance(result["index"], expected_results_shape["index"]) + assert isinstance( + result["relevance_score"], expected_results_shape["relevance_score"] + ) + assert isinstance(response.meta, expected_response_shape["meta"]) + + if custom_llm_provider == "cohere": + + assert isinstance( + response.meta["api_version"], expected_meta_shape["api_version"] + ) + assert isinstance( + response.meta["api_version"]["version"], + expected_api_version_shape["version"], + ) + assert isinstance( + response.meta["billed_units"], expected_meta_shape["billed_units"] + ) + assert isinstance( + response.meta["billed_units"]["search_units"], + expected_billed_units_shape["search_units"], + ) + + +class BaseLLMRerankTest(ABC): + """ + Abstract base test class that enforces a common test across all test classes. + """ + + @abstractmethod + def get_base_rerank_call_args(self) -> dict: + """Must return the base rerank call args""" + pass + + @abstractmethod + def get_custom_llm_provider(self) -> litellm.LlmProviders: + """Must return the custom llm provider""" + pass + + @pytest.mark.asyncio() + @pytest.mark.parametrize("sync_mode", [True, False]) + async def test_basic_rerank(self, sync_mode): + rerank_call_args = self.get_base_rerank_call_args() + custom_llm_provider = self.get_custom_llm_provider() + if sync_mode is True: + response = litellm.rerank( + **rerank_call_args, + query="hello", + documents=["hello", "world"], + top_n=3, + ) + + print("re rank response: ", response) + + assert response.id is not None + assert response.results is not None + + assert_response_shape( + response=response, custom_llm_provider=custom_llm_provider.value + ) + else: + response = await litellm.arerank( + **rerank_call_args, + query="hello", + documents=["hello", "world"], + top_n=3, + ) + + print("async re rank response: ", response) + + assert response.id is not None + assert response.results is not None + + assert_response_shape( + response=response, custom_llm_provider=custom_llm_provider.value + ) diff --git a/tests/llm_translation/test_jina_ai.py b/tests/llm_translation/test_jina_ai.py new file mode 100644 index 000000000..c169b5587 --- /dev/null +++ b/tests/llm_translation/test_jina_ai.py @@ -0,0 +1,23 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + + +from base_rerank_unit_tests import BaseLLMRerankTest +import litellm + + +class TestJinaAI(BaseLLMRerankTest): + def get_custom_llm_provider(self) -> litellm.LlmProviders: + return litellm.LlmProviders.JINA_AI + + def get_base_rerank_call_args(self) -> dict: + return { + "model": "jina_ai/jina-reranker-v2-base-multilingual", + } diff --git a/tests/local_testing/test_get_llm_provider.py b/tests/local_testing/test_get_llm_provider.py index 6654c10c2..423ffe2fd 100644 --- a/tests/local_testing/test_get_llm_provider.py +++ b/tests/local_testing/test_get_llm_provider.py @@ -157,7 +157,7 @@ def test_get_llm_provider_jina_ai(): model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider( model="jina_ai/jina-embeddings-v3", ) - assert custom_llm_provider == "openai_like" + assert custom_llm_provider == "jina_ai" assert api_base == "https://api.jina.ai/v1" assert model == "jina-embeddings-v3"