forked from phoenix/litellm-mirror
feat(jina_ai/): add rerank support
Closes https://github.com/BerriAI/litellm/issues/6691
This commit is contained in: commit 51ec270501 (parent 1988b13f46)
10 changed files with 319 additions and 121 deletions
```diff
@@ -76,4 +76,4 @@ class JinaAIEmbeddingConfig:
             or get_secret_str("JINA_AI_API_KEY")
             or get_secret_str("JINA_AI_TOKEN")
         )
-        return LlmProviders.OPENAI_LIKE.value, api_base, dynamic_api_key
+        return LlmProviders.JINA_AI.value, api_base, dynamic_api_key
```
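The effect of this one-line provider switch is easiest to see from the call site. A minimal sketch, using the return order asserted by the updated test at the bottom of this diff:

```python
import litellm

# Jina AI models now resolve to their own provider instead of the
# generic "openai_like" passthrough.
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
    model="jina_ai/jina-embeddings-v3",
)
print(custom_llm_provider)  # "jina_ai"
print(api_base)             # "https://api.jina.ai/v1"
```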
96  litellm/llms/jina_ai/rerank/handler.py  (new file)

```python
"""
Jina AI rerank handler.

LiteLLM supports the rerank API format; no parameter transformation occurs.
"""

import uuid
from typing import Any, Dict, List, Optional, Union

import httpx
from pydantic import BaseModel

import litellm
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig
from litellm.types.rerank import RerankRequest, RerankResponse


class JinaAIRerank(BaseLLM):
    def rerank(
        self,
        model: str,
        api_key: str,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        _is_async: Optional[bool] = False,
    ) -> RerankResponse:
        client = _get_httpx_client()

        request_data = RerankRequest(
            model=model,
            query=query,
            top_n=top_n,
            documents=documents,
            rank_fields=rank_fields,
            return_documents=return_documents,
        )

        # exclude None values from request_data
        request_data_dict = request_data.dict(exclude_none=True)

        if _is_async:
            return self.async_rerank(request_data_dict, api_key)  # type: ignore  # call the async method

        response = client.post(
            "https://api.jina.ai/v1/rerank",
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "authorization": f"Bearer {api_key}",
            },
            json=request_data_dict,
        )

        if response.status_code != 200:
            raise Exception(response.text)

        _json_response = response.json()

        return JinaAIRerankConfig()._transform_response(_json_response)

    async def async_rerank(  # new async method
        self,
        request_data_dict: Dict[str, Any],
        api_key: str,
    ) -> RerankResponse:
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.JINA_AI
        )  # use the async client

        response = await client.post(
            "https://api.jina.ai/v1/rerank",
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "authorization": f"Bearer {api_key}",
            },
            json=request_data_dict,
        )

        if response.status_code != 200:
            raise Exception(response.text)

        _json_response = response.json()

        return JinaAIRerankConfig()._transform_response(_json_response)
```
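For reference, a minimal sketch of calling this handler directly. It performs a real HTTP call to `https://api.jina.ai/v1/rerank`, so a valid `JINA_AI_API_KEY` is assumed; the query and documents are placeholders:

```python
import os

from litellm.llms.jina_ai.rerank.handler import JinaAIRerank

handler = JinaAIRerank()
response = handler.rerank(
    model="jina-reranker-v2-base-multilingual",
    api_key=os.environ["JINA_AI_API_KEY"],
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
    top_n=1,
)
# RerankResponse with results like [{"index": 0, "relevance_score": ...}]
print(response.results)
```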
36  litellm/llms/jina_ai/rerank/transformation.py  (new file)

```python
"""
Transformation logic from Cohere's `/v1/rerank` format to Jina AI's `/v1/rerank` format.

Why a separate file? To make it easy to see how the transformation works.

Docs: https://jina.ai/reranker
"""

import uuid
from typing import List, Optional

from litellm.types.rerank import (
    RerankBilledUnits,
    RerankResponse,
    RerankResponseMeta,
    RerankTokens,
)


class JinaAIRerankConfig:
    def _transform_response(self, response: dict) -> RerankResponse:
        _billed_units = RerankBilledUnits(**response.get("usage", {}))
        _tokens = RerankTokens(**response.get("usage", {}))
        rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)

        _results: Optional[List[dict]] = response.get("results")

        if _results is None:
            raise ValueError(f"No results found in the response={response}")

        return RerankResponse(
            id=response.get("id") or str(uuid.uuid4()),
            results=_results,
            meta=rerank_meta,
        )
```
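A worked sketch of the transformation above. The `usage` payload shape is an assumption inferred from the `RerankBilledUnits`/`RerankTokens` fields added later in this diff, not a documented Jina AI response:

```python
from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig

raw_response = {
    "id": "abc-123",
    "results": [{"index": 0, "relevance_score": 0.91}],
    "usage": {"total_tokens": 42},  # assumed shape; unpacked into billed_units/tokens
}

rerank_response = JinaAIRerankConfig()._transform_response(raw_response)
print(rerank_response.id)       # "abc-123" (a uuid4 is generated when the id is missing)
print(rerank_response.results)  # [{"index": 0, "relevance_score": 0.91}]
```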
```diff
@@ -3455,7 +3455,7 @@ def embedding(  # noqa: PLR0915
             client=client,
             aembedding=aembedding,
         )
-    elif custom_llm_provider == "openai_like":
+    elif custom_llm_provider == "openai_like" or custom_llm_provider == "jina_ai":
         api_base = (
             api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE")
         )
```
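With the embedding route above accepting `jina_ai` alongside `openai_like`, a Jina embedding call goes through the OpenAI-compatible path unchanged. A minimal sketch, assuming `JINA_AI_API_KEY` is set in the environment:

```python
import litellm

response = litellm.embedding(
    model="jina_ai/jina-embeddings-v3",
    input=["hello world"],
)
print(len(response.data))  # one embedding per input string
```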
```diff
@@ -1,122 +1,15 @@
 model_list:
-  - model_name: "*"
-    litellm_params:
-      model: claude-3-5-sonnet-20240620
-      api_key: os.environ/ANTHROPIC_API_KEY
-  - model_name: claude-3-5-sonnet-aihubmix
-    litellm_params:
-      model: openai/claude-3-5-sonnet-20240620
-      input_cost_per_token: 0.000003 # 3$/M
-      output_cost_per_token: 0.000015 # 15$/M
-      api_base: "https://exampleopenaiendpoint-production.up.railway.app"
-      api_key: my-fake-key
-  - model_name: fake-openai-endpoint-2
-    litellm_params:
-      model: openai/my-fake-model
-      api_key: my-fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
-      stream_timeout: 0.001
-      timeout: 1
-      rpm: 1
-  - model_name: fake-openai-endpoint
-    litellm_params:
-      model: openai/my-fake-model
-      api_key: my-fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
-  ## bedrock chat completions
-  - model_name: "*anthropic.claude*"
-    litellm_params:
-      model: bedrock/*anthropic.claude*
-      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
-      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
-      aws_region_name: os.environ/AWS_REGION_NAME
-      guardrailConfig:
-        "guardrailIdentifier": "h4dsqwhp6j66"
-        "guardrailVersion": "2"
-        "trace": "enabled"
-
-  ## bedrock embeddings
-  - model_name: "*amazon.titan-embed-*"
-    litellm_params:
-      model: bedrock/amazon.titan-embed-*
-      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
-      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
-      aws_region_name: os.environ/AWS_REGION_NAME
-  - model_name: "*cohere.embed-*"
-    litellm_params:
-      model: bedrock/cohere.embed-*
-      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
-      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
-      aws_region_name: os.environ/AWS_REGION_NAME
-
-  - model_name: "bedrock/*"
-    litellm_params:
-      model: bedrock/*
-      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
-      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
-      aws_region_name: os.environ/AWS_REGION_NAME
-
-  # GPT-4 Turbo Models
   - model_name: gpt-4
     litellm_params:
-      model: azure/chatgpt-v-2
-      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      api_version: "2023-05-15"
-      api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
-      rpm: 480
-      timeout: 300
-      stream_timeout: 60
-
-litellm_settings:
-  fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
-  # callbacks: ["otel", "prometheus"]
-  default_redis_batch_cache_expiry: 10
-  # default_team_settings:
-  #   - team_id: "dbe2f686-a686-4896-864a-4c3924458709"
-  #     success_callback: ["langfuse"]
-  #     langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
-  #     langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1
-
-# litellm_settings:
-#   cache: True
-#   cache_params:
-#     type: redis
-
-#     # disable caching on the actual API call
-#     supported_call_types: []
-
-#     # see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url
-#     host: os.environ/REDIS_HOST
-#     port: os.environ/REDIS_PORT
-#     password: os.environ/REDIS_PASSWORD
-
-#     # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
-#     # see https://docs.litellm.ai/docs/proxy/prometheus
-#   callbacks: ['otel']
+      model: gpt-4
+  - model_name: rerank-model
+    litellm_params:
+      model: jina_ai/jina-reranker-v2-base-multilingual
 
-
-# # router_settings:
-# #   routing_strategy: latency-based-routing
-# #   routing_strategy_args:
-# #     # only assign 40% of traffic to the fastest deployment to avoid overloading it
-# #     lowest_latency_buffer: 0.4
-
-# #     # consider last five minutes of calls for latency calculation
-# #     ttl: 300
-# #   redis_host: os.environ/REDIS_HOST
-# #   redis_port: os.environ/REDIS_PORT
-# #   redis_password: os.environ/REDIS_PASSWORD
-
-# #   # see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml
-# # general_settings:
-# #   master_key: os.environ/LITELLM_MASTER_KEY
-# #   database_url: os.environ/DATABASE_URL
-# #   disable_master_key_return: true
-# #   # alerting: ['slack', 'email']
-# #   alerting: ['email']
-
-# #   # Batch write spend updates every 60s
-# #   proxy_batch_write_at: 60
-
-# #   # see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl
-# #   # our api keys rarely change
-# #   user_api_key_cache_ttl: 3600
+router_settings:
+  model_group_alias:
+    "gpt-4-turbo": # Aliased model name
+      model: "gpt-4" # Actual model name in 'model_list'
+      hidden: true
```
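A sketch of exercising the new `rerank-model` entry through the proxy. The port, master key, and the `/rerank` route are assumptions based on litellm's default proxy setup; none of them appear in this diff:

```python
import requests

resp = requests.post(
    "http://0.0.0.0:4000/rerank",  # assumed default proxy address and route
    headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
    json={
        "model": "rerank-model",
        "query": "What is the capital of France?",
        "documents": ["Paris is the capital of France.", "Berlin is in Germany."],
    },
)
print(resp.json())
```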
```diff
@@ -8,6 +8,7 @@ from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.azure_ai.rerank import AzureAIRerank
 from litellm.llms.cohere.rerank import CohereRerank
+from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
 from litellm.llms.together_ai.rerank import TogetherAIRerank
 from litellm.secret_managers.main import get_secret
 from litellm.types.rerank import RerankRequest, RerankResponse
@@ -19,6 +20,7 @@ from litellm.utils import client, exception_type, supports_httpx_timeout
 cohere_rerank = CohereRerank()
 together_rerank = TogetherAIRerank()
 azure_ai_rerank = AzureAIRerank()
+jina_ai_rerank = JinaAIRerank()
 #################################################
 
 
```
```diff
@@ -247,7 +249,23 @@ def rerank(
             api_key=api_key,
             _is_async=_is_async,
         )
+    elif _custom_llm_provider == "jina_ai":
+        if dynamic_api_key is None:
+            raise ValueError(
+                "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
+            )
+        response = jina_ai_rerank.rerank(
+            model=model,
+            api_key=dynamic_api_key,
+            query=query,
+            documents=documents,
+            top_n=top_n,
+            rank_fields=rank_fields,
+            return_documents=return_documents,
+            max_chunks_per_doc=max_chunks_per_doc,
+            _is_async=_is_async,
+        )
     else:
         raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
```
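End to end, the new branch is reached through the public SDK entrypoint. A minimal sketch, assuming `JINA_AI_API_KEY` is exported so `dynamic_api_key` resolves:

```python
import litellm

response = litellm.rerank(
    model="jina_ai/jina-reranker-v2-base-multilingual",
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
    top_n=1,
)
print(response.results)
```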
```diff
@@ -7,6 +7,7 @@ https://docs.cohere.com/reference/rerank
 from typing import List, Optional, Union
 
 from pydantic import BaseModel, PrivateAttr
+from typing_extensions import TypedDict
 
 
 class RerankRequest(BaseModel):
@@ -19,10 +20,26 @@ class RerankRequest(BaseModel):
     max_chunks_per_doc: Optional[int] = None
 
 
+class RerankBilledUnits(TypedDict, total=False):
+    search_units: int
+    total_tokens: int
+
+
+class RerankTokens(TypedDict, total=False):
+    input_tokens: int
+    output_tokens: int
+
+
+class RerankResponseMeta(TypedDict, total=False):
+    api_version: dict
+    billed_units: RerankBilledUnits
+    tokens: RerankTokens
+
+
 class RerankResponse(BaseModel):
     id: str
     results: List[dict]  # Contains index and relevance_score
-    meta: Optional[dict] = None  # Contains api_version and billed_units
+    meta: Optional[RerankResponseMeta] = None  # Contains api_version and billed_units
 
     # Define private attributes using PrivateAttr
     _hidden_params: dict = PrivateAttr(default_factory=dict)
```
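Since the new meta types are `total=False` TypedDicts, every key stays optional while `meta` gains key-level typing. A small sketch of constructing the typed response:

```python
from litellm.types.rerank import (
    RerankBilledUnits,
    RerankResponse,
    RerankResponseMeta,
    RerankTokens,
)

meta: RerankResponseMeta = {
    "billed_units": RerankBilledUnits(search_units=1),
    "tokens": RerankTokens(input_tokens=10, output_tokens=0),
}
resp = RerankResponse(
    id="example-id",
    results=[{"index": 0, "relevance_score": 0.9}],
    meta=meta,
)
print(resp.meta["billed_units"]["search_units"])  # 1
```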
115  tests/llm_translation/base_rerank_unit_tests.py  (new file)

```python
import asyncio
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import (
    CustomStreamWrapper,
    get_supported_openai_params,
    get_optional_params,
)

from abc import ABC, abstractmethod


def assert_response_shape(response, custom_llm_provider):
    expected_response_shape = {"id": str, "results": list, "meta": dict}

    expected_results_shape = {"index": int, "relevance_score": float}

    expected_meta_shape = {"api_version": dict, "billed_units": dict}

    expected_api_version_shape = {"version": str}

    expected_billed_units_shape = {"search_units": int}

    assert isinstance(response.id, expected_response_shape["id"])
    assert isinstance(response.results, expected_response_shape["results"])
    for result in response.results:
        assert isinstance(result["index"], expected_results_shape["index"])
        assert isinstance(
            result["relevance_score"], expected_results_shape["relevance_score"]
        )
    assert isinstance(response.meta, expected_response_shape["meta"])

    if custom_llm_provider == "cohere":
        assert isinstance(
            response.meta["api_version"], expected_meta_shape["api_version"]
        )
        assert isinstance(
            response.meta["api_version"]["version"],
            expected_api_version_shape["version"],
        )
        assert isinstance(
            response.meta["billed_units"], expected_meta_shape["billed_units"]
        )
        assert isinstance(
            response.meta["billed_units"]["search_units"],
            expected_billed_units_shape["search_units"],
        )


class BaseLLMRerankTest(ABC):
    """
    Abstract base test class that enforces a common test across all test classes.
    """

    @abstractmethod
    def get_base_rerank_call_args(self) -> dict:
        """Must return the base rerank call args"""
        pass

    @abstractmethod
    def get_custom_llm_provider(self) -> litellm.LlmProviders:
        """Must return the custom llm provider"""
        pass

    @pytest.mark.asyncio()
    @pytest.mark.parametrize("sync_mode", [True, False])
    async def test_basic_rerank(self, sync_mode):
        rerank_call_args = self.get_base_rerank_call_args()
        custom_llm_provider = self.get_custom_llm_provider()
        if sync_mode is True:
            response = litellm.rerank(
                **rerank_call_args,
                query="hello",
                documents=["hello", "world"],
                top_n=3,
            )

            print("rerank response: ", response)

            assert response.id is not None
            assert response.results is not None

            assert_response_shape(
                response=response, custom_llm_provider=custom_llm_provider.value
            )
        else:
            response = await litellm.arerank(
                **rerank_call_args,
                query="hello",
                documents=["hello", "world"],
                top_n=3,
            )

            print("async rerank response: ", response)

            assert response.id is not None
            assert response.results is not None

            assert_response_shape(
                response=response, custom_llm_provider=custom_llm_provider.value
            )
```
23  tests/llm_translation/test_jina_ai.py  (new file)

```python
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path


from base_rerank_unit_tests import BaseLLMRerankTest
import litellm


class TestJinaAI(BaseLLMRerankTest):
    def get_custom_llm_provider(self) -> litellm.LlmProviders:
        return litellm.LlmProviders.JINA_AI

    def get_base_rerank_call_args(self) -> dict:
        return {
            "model": "jina_ai/jina-reranker-v2-base-multilingual",
        }
```
```diff
@@ -157,7 +157,7 @@ def test_get_llm_provider_jina_ai():
     model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
         model="jina_ai/jina-embeddings-v3",
     )
-    assert custom_llm_provider == "openai_like"
+    assert custom_llm_provider == "jina_ai"
     assert api_base == "https://api.jina.ai/v1"
     assert model == "jina-embeddings-v3"
```