forked from phoenix/litellm-mirror
* fix(ollama.py): fix get model info request Fixes https://github.com/BerriAI/litellm/issues/6703 * feat(anthropic/chat/transformation.py): support passing user id to anthropic via openai 'user' param * docs(anthropic.md): document all supported openai params for anthropic * test: fix tests * fix: fix tests * feat(jina_ai/): add rerank support Closes https://github.com/BerriAI/litellm/issues/6691 * test: handle service unavailable error * fix(handler.py): refactor together ai rerank call * test: update test to handle overloaded error * test: fix test * Litellm router trace (#6742) * feat(router.py): add trace_id to parent functions - allows tracking retry/fallbacks * feat(router.py): log trace id across retry/fallback logic allows grouping llm logs for the same request * test: fix tests * fix: fix test * fix(transformation.py): only set non-none stop_sequences * Litellm router disable fallbacks (#6743) * bump: version 1.52.6 → 1.52.7 * feat(router.py): enable dynamically disabling fallbacks Allows for enabling/disabling fallbacks per key * feat(litellm_pre_call_utils.py): support setting 'disable_fallbacks' on litellm key * test: fix test * fix(exception_mapping_utils.py): map 'model is overloaded' to internal server error * test: handle gemini error * test: fix test * fix: new run
115 lines
3.5 KiB
Python
115 lines
3.5 KiB
Python
import asyncio
|
|
import httpx
|
|
import json
|
|
import pytest
|
|
import sys
|
|
from typing import Any, Dict, List
|
|
from unittest.mock import MagicMock, Mock, patch
|
|
import os
|
|
|
|
sys.path.insert(
|
|
0, os.path.abspath("../..")
|
|
)  # Adds the directory two levels up (relative to the current working directory) to the system path
|
|
import litellm
|
|
from litellm.exceptions import BadRequestError
|
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
|
from litellm.utils import (
|
|
CustomStreamWrapper,
|
|
get_supported_openai_params,
|
|
get_optional_params,
|
|
)
|
|
|
|
# test_example.py
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
def assert_response_shape(response, custom_llm_provider):
    """
    Assert that a rerank response has the expected structure.

    Checks performed for every provider:
      - ``response.id`` is a ``str``
      - ``response.results`` is a ``list`` whose entries expose an ``int``
        under ``"index"`` and a ``float`` under ``"relevance_score"``
      - ``response.meta`` is a ``dict``

    When ``custom_llm_provider`` equals ``"cohere"``, the metadata is checked
    further: ``meta["api_version"]`` must be a ``dict`` holding a ``str``
    ``"version"``, and ``meta["billed_units"]`` must be a ``dict`` holding an
    ``int`` ``"search_units"``.

    Raises:
        AssertionError: on the first value whose type does not match.
        KeyError: if an expected key is missing from a result or from the
            cohere metadata.
    """
    # Top-level fields.
    assert isinstance(response.id, str)
    assert isinstance(response.results, list)

    # Per-result fields.
    for entry in response.results:
        assert isinstance(entry["index"], int)
        assert isinstance(entry["relevance_score"], float)

    assert isinstance(response.meta, dict)

    # Only cohere responses get the detailed metadata checks.
    if custom_llm_provider != "cohere":
        return

    api_version = response.meta["api_version"]
    assert isinstance(api_version, dict)
    assert isinstance(api_version["version"], str)

    billed_units = response.meta["billed_units"]
    assert isinstance(billed_units, dict)
    assert isinstance(billed_units["search_units"], int)
|
|
|
|
|
|
class BaseLLMRerankTest(ABC):
    """
    Abstract base class that gives every provider-specific rerank test class
    the same shared test.

    Subclasses supply the provider-specific call arguments and the provider
    identifier; ``test_basic_rerank`` is inherited as-is.
    """

    @abstractmethod
    def get_base_rerank_call_args(self) -> dict:
        """Return the base keyword arguments for the rerank call."""

    @abstractmethod
    def get_custom_llm_provider(self) -> litellm.LlmProviders:
        """Return the custom llm provider under test."""

    @pytest.mark.asyncio()
    @pytest.mark.parametrize("sync_mode", [True, False])
    async def test_basic_rerank(self, sync_mode):
        """
        Issue one rerank request, through ``litellm.rerank`` when
        ``sync_mode`` is True and ``litellm.arerank`` otherwise, then check
        that the response has an id, results, and the expected shape.
        """
        base_args = self.get_base_rerank_call_args()
        provider = self.get_custom_llm_provider()

        # Request parameters shared by the sync and async code paths.
        request_kwargs = {
            "query": "hello",
            "documents": ["hello", "world"],
            "top_n": 3,
        }

        if sync_mode is True:
            response = litellm.rerank(**base_args, **request_kwargs)
            print("re rank response: ", response)
        else:
            response = await litellm.arerank(**base_args, **request_kwargs)
            print("async re rank response: ", response)

        # Both code paths are validated identically.
        assert response.id is not None
        assert response.results is not None

        assert_response_shape(
            response=response, custom_llm_provider=provider.value
        )
|