litellm/tests/llm_translation/base_rerank_unit_tests.py
Krish Dholakia e9aa492af3
LiteLLM Minor Fixes & Improvement (11/14/2024) (#6730)
* fix(ollama.py): fix get model info request

Fixes https://github.com/BerriAI/litellm/issues/6703

* feat(anthropic/chat/transformation.py): support passing user id to anthropic via openai 'user' param

* docs(anthropic.md): document all supported openai params for anthropic

* test: fix tests

* fix: fix tests

* feat(jina_ai/): add rerank support

Closes https://github.com/BerriAI/litellm/issues/6691

* test: handle service unavailable error

* fix(handler.py): refactor together ai rerank call

* test: update test to handle overloaded error

* test: fix test

* Litellm router trace (#6742)

* feat(router.py): add trace_id to parent functions - allows tracking retry/fallbacks

* feat(router.py): log trace id across retry/fallback logic

allows grouping llm logs for the same request

* test: fix tests

* fix: fix test

* fix(transformation.py): only set non-none stop_sequences

* Litellm router disable fallbacks (#6743)

* bump: version 1.52.6 → 1.52.7

* feat(router.py): enable dynamically disabling fallbacks

Allows for enabling/disabling fallbacks per key

* feat(litellm_pre_call_utils.py): support setting 'disable_fallbacks' on litellm key

* test: fix test

* fix(exception_mapping_utils.py): map 'model is overloaded' to internal server error

* test: handle gemini error

* test: fix test

* fix: new run
2024-11-15 01:02:54 +05:30

115 lines
3.5 KiB
Python

import asyncio
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import (
CustomStreamWrapper,
get_supported_openai_params,
get_optional_params,
)
# Shared helpers and abstract base class for provider rerank tests.
from abc import ABC, abstractmethod
def assert_response_shape(response, custom_llm_provider):
    """Assert that a rerank response exposes the expected fields and types.

    Checks performed for every provider:
      * ``response.id`` is a ``str``
      * ``response.results`` is a ``list`` whose items each carry an ``int``
        ``"index"`` and a ``float`` ``"relevance_score"``
      * ``response.meta`` is a ``dict``

    When ``custom_llm_provider == "cohere"`` the meta payload is inspected
    further: ``meta["api_version"]`` must be a ``dict`` with a ``str``
    ``"version"``, and ``meta["billed_units"]`` must be a ``dict`` with an
    ``int`` ``"search_units"``.

    Args:
        response: Object with ``id``, ``results`` and ``meta`` attributes.
        custom_llm_provider: Provider name as a string; only the literal
            ``"cohere"`` enables the extra meta checks.

    Raises:
        AssertionError: If any inspected value has the wrong type. The
            message names the offending field and the actual type.

    Note:
        Item/key lookups (``result["index"]``, ``meta["api_version"]``, ...)
        are not guarded; a missing key surfaces as whatever error the
        container raises.
    """

    def _expect_type(value, expected_type, label):
        # Explicit message: this helper's file name does not look like a
        # test module, so pytest's assert introspection may not be applied
        # here and a bare assert would fail without any detail.
        assert isinstance(value, expected_type), (
            f"{label}: expected {expected_type.__name__}, "
            f"got {type(value).__name__} ({value!r})"
        )

    _expect_type(response.id, str, "response.id")
    _expect_type(response.results, list, "response.results")
    for position, result in enumerate(response.results):
        _expect_type(
            result["index"], int, f"response.results[{position}]['index']"
        )
        _expect_type(
            result["relevance_score"],
            float,
            f"response.results[{position}]['relevance_score']",
        )
    _expect_type(response.meta, dict, "response.meta")

    if custom_llm_provider == "cohere":
        meta = response.meta
        _expect_type(meta["api_version"], dict, "response.meta['api_version']")
        _expect_type(
            meta["api_version"]["version"],
            str,
            "response.meta['api_version']['version']",
        )
        _expect_type(
            meta["billed_units"], dict, "response.meta['billed_units']"
        )
        _expect_type(
            meta["billed_units"]["search_units"],
            int,
            "response.meta['billed_units']['search_units']",
        )
class BaseLLMRerankTest(ABC):
    """
    Abstract rerank test suite shared by provider-specific test classes.

    Subclasses supply the provider-specific call arguments and provider
    identifier; the common test below runs against both the sync and async
    rerank entry points.
    """

    @abstractmethod
    def get_base_rerank_call_args(self) -> dict:
        """Return the keyword arguments that are splatted into every rerank call."""
        ...

    @abstractmethod
    def get_custom_llm_provider(self) -> litellm.LlmProviders:
        """Return the provider whose ``.value`` is passed to the shape assertions."""
        ...

    @pytest.mark.asyncio()
    @pytest.mark.parametrize("sync_mode", [True, False])
    async def test_basic_rerank(self, sync_mode):
        """Issue one rerank request (sync or async) and validate the response."""
        base_args = self.get_base_rerank_call_args()
        provider = self.get_custom_llm_provider()

        # Request payload common to both code paths.
        request_kwargs = {
            "query": "hello",
            "documents": ["hello", "world"],
            "top_n": 3,
        }

        if sync_mode is True:
            log_prefix = "re rank response: "
            response = litellm.rerank(**base_args, **request_kwargs)
        else:
            log_prefix = "async re rank response: "
            response = await litellm.arerank(**base_args, **request_kwargs)

        print(log_prefix, response)

        assert response.id is not None
        assert response.results is not None
        assert_response_shape(
            response=response, custom_llm_provider=provider.value
        )