diff --git a/docs/my-website/docs/providers/custom_llm_server.md b/docs/my-website/docs/providers/custom_llm_server.md
index 6d2015010..2adb6a67c 100644
--- a/docs/my-website/docs/providers/custom_llm_server.md
+++ b/docs/my-website/docs/providers/custom_llm_server.md
@@ -251,6 +251,105 @@ Expected Response
}
```
+## Additional Parameters
+
+Additional parameters are passed to your custom handler inside the `optional_params` key of its `completion` or `image_generation` method.
+
+Here's how to set this:
+
+
+
+
+```python
+import litellm
+from litellm import CustomLLM, completion, get_llm_provider
+
+
+class MyCustomLLM(CustomLLM):
+ def completion(self, *args, **kwargs) -> litellm.ModelResponse:
+ assert kwargs["optional_params"] == {"my_custom_param": "my-custom-param"} # 👈 CHECK HERE
+ return litellm.completion(
+ model="gpt-3.5-turbo",
+ messages=[{"role": "user", "content": "Hello world"}],
+ mock_response="Hi!",
+ ) # type: ignore
+
+my_custom_llm = MyCustomLLM()
+
+litellm.custom_provider_map = [  # 👈 KEY STEP - REGISTER HANDLER
+    {"provider": "my-custom-llm", "custom_handler": my_custom_llm}
+]
+
+resp = completion(
+    model="my-custom-llm/my-model",
+    messages=[{"role": "user", "content": "Hello world!"}],
+    my_custom_param="my-custom-param",
+)
+```
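+
+When you pass a non-OpenAI parameter such as `my_custom_param`, LiteLLM forwards it to your handler inside `optional_params`, which is what the assert above checks.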
+
+
+
+
+
+1. Set up your `custom_handler.py` file
+```python
+import time
+from typing import Any, Optional, Union
+
+import httpx
+
+import litellm
+from litellm import CustomLLM
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.types.utils import ImageResponse, ImageObject
+
+
+class MyCustomLLM(CustomLLM):
+    async def aimage_generation(
+        self,
+        model: str,
+        prompt: str,
+        model_response: ImageResponse,
+        optional_params: dict,
+        logging_obj: Any,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> ImageResponse:
+ assert optional_params == {"my_custom_param": "my-custom-param"} # 👈 CHECK HERE
+ return ImageResponse(
+ created=int(time.time()),
+ data=[ImageObject(url="https://example.com/image.png")],
+ )
+
+my_custom_llm = MyCustomLLM()
+```
+
+
+2. Add to `config.yaml`
+
+In the config below, we pass:
+
+- python_filename: `custom_handler.py`
+- custom_handler_instance_name: `my_custom_llm`. This is defined in Step 1.
+- custom_handler: `custom_handler.my_custom_llm`
+
+```yaml
+model_list:
+ - model_name: "test-model"
+ litellm_params:
+ model: "openai/text-embedding-ada-002"
+ - model_name: "my-custom-model"
+ litellm_params:
+ model: "my-custom-llm/my-model"
+ my_custom_param: "my-custom-param" # 👈 CUSTOM PARAM
+
+litellm_settings:
+ custom_provider_map:
+ - {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}
+```
+
+Start the proxy:
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/v1/images/generations' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+ "model": "my-custom-model",
+    "prompt": "A cute baby sea otter"
+}'
+```
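+
+Expected Response (a sketch: the body mirrors the `ImageResponse` returned by `aimage_generation` above; the `created` timestamp is illustrative):
+
+```json
+{
+    "created": 1721955063,
+    "data": [
+        {
+            "url": "https://example.com/image.png"
+        }
+    ]
+}
+```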
+
+
+
+
+
## Custom Handler Spec
diff --git a/litellm/caching.py b/litellm/caching.py
index 91d9e6996..c9767b624 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -20,13 +20,13 @@ from datetime import timedelta
from enum import Enum
from typing import Any, List, Literal, Optional, Tuple, Union
-from openai._models import BaseModel as OpenAIObject
+from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
-from litellm.types.utils import all_litellm_params
+from litellm.types.utils import CachingSupportedCallTypes, all_litellm_params
def print_verbose(print_statement):
@@ -2139,20 +2139,7 @@ class Cache:
default_in_memory_ttl: Optional[float] = None,
default_in_redis_ttl: Optional[float] = None,
similarity_threshold: Optional[float] = None,
- supported_call_types: Optional[
- List[
- Literal[
- "completion",
- "acompletion",
- "embedding",
- "aembedding",
- "atranscription",
- "transcription",
- "atext_completion",
- "text_completion",
- ]
- ]
- ] = [
+ supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
"completion",
"acompletion",
"embedding",
@@ -2161,6 +2148,8 @@ class Cache:
"transcription",
"atext_completion",
"text_completion",
+ "arerank",
+ "rerank",
],
# s3 Bucket, boto3 configuration
s3_bucket_name: Optional[str] = None,
@@ -2353,9 +2342,20 @@ class Cache:
"file",
"language",
]
+ rerank_only_kwargs = [
+ "top_n",
+ "rank_fields",
+ "return_documents",
+ "max_chunks_per_doc",
+ "documents",
+ "query",
+ ]
# combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
combined_kwargs = (
- completion_kwargs + embedding_only_kwargs + transcription_only_kwargs
+ completion_kwargs
+ + embedding_only_kwargs
+ + transcription_only_kwargs
+ + rerank_only_kwargs
)
litellm_param_kwargs = all_litellm_params
for param in kwargs:
@@ -2557,7 +2557,7 @@ class Cache:
else:
cache_key = self.get_cache_key(*args, **kwargs)
if cache_key is not None:
- if isinstance(result, OpenAIObject):
+ if isinstance(result, BaseModel):
result = result.model_dump_json()
## DEFAULT TTL ##
@@ -2778,20 +2778,7 @@ def enable_cache(
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
- supported_call_types: Optional[
- List[
- Literal[
- "completion",
- "acompletion",
- "embedding",
- "aembedding",
- "atranscription",
- "transcription",
- "atext_completion",
- "text_completion",
- ]
- ]
- ] = [
+ supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
"completion",
"acompletion",
"embedding",
@@ -2800,6 +2787,8 @@ def enable_cache(
"transcription",
"atext_completion",
"text_completion",
+ "arerank",
+ "rerank",
],
**kwargs,
):
@@ -2847,20 +2836,7 @@ def update_cache(
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
- supported_call_types: Optional[
- List[
- Literal[
- "completion",
- "acompletion",
- "embedding",
- "aembedding",
- "atranscription",
- "transcription",
- "atext_completion",
- "text_completion",
- ]
- ]
- ] = [
+ supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
"completion",
"acompletion",
"embedding",
@@ -2869,6 +2845,8 @@ def update_cache(
"transcription",
"atext_completion",
"text_completion",
+ "arerank",
+ "rerank",
],
**kwargs,
):
diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py
index 177ff735c..c79b43422 100644
--- a/litellm/integrations/langfuse.py
+++ b/litellm/integrations/langfuse.py
@@ -191,7 +191,6 @@ class LangFuseLogger:
pass
# end of processing langfuse ########################
-
if (
level == "ERROR"
and status_message is not None
diff --git a/litellm/litellm_core_utils/exception_mapping_utils.py b/litellm/litellm_core_utils/exception_mapping_utils.py
index 61cca6e07..2572e695f 100644
--- a/litellm/litellm_core_utils/exception_mapping_utils.py
+++ b/litellm/litellm_core_utils/exception_mapping_utils.py
@@ -67,25 +67,6 @@ def get_error_message(error_obj) -> Optional[str]:
####### EXCEPTION MAPPING ################
-def _get_litellm_response_headers(
- original_exception: Exception,
-) -> Optional[httpx.Headers]:
- """
- Extract and return the response headers from a mapped exception, if present.
-
- Used for accurate retry logic.
- """
- _response_headers: Optional[httpx.Headers] = None
- try:
- _response_headers = getattr(
- original_exception, "litellm_response_headers", None
- )
- except Exception:
- return None
-
- return _response_headers
-
-
def _get_response_headers(original_exception: Exception) -> Optional[httpx.Headers]:
"""
Extract and return the response headers from an exception, if present.
@@ -96,8 +77,12 @@ def _get_response_headers(original_exception: Exception) -> Optional[httpx.Heade
try:
_response_headers = getattr(original_exception, "headers", None)
error_response = getattr(original_exception, "response", None)
- if _response_headers is None and error_response:
+ if not _response_headers and error_response:
_response_headers = getattr(error_response, "headers", None)
+ if not _response_headers:
+ _response_headers = getattr(
+ original_exception, "litellm_response_headers", None
+ )
except Exception:
return None
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index a641be019..d3f15e6bc 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -84,6 +84,7 @@ from ..integrations.s3 import S3Logger
from ..integrations.supabase import Supabase
from ..integrations.traceloop import TraceloopLogger
from ..integrations.weights_biases import WeightsBiasesLogger
+from .exception_mapping_utils import _get_response_headers
try:
from ..proxy.enterprise.enterprise_callbacks.generic_api_callback import (
@@ -1813,6 +1814,7 @@ class Logging:
logging_obj=self,
status="failure",
error_str=str(exception),
+ original_exception=exception,
)
)
return start_time, end_time
@@ -2654,6 +2656,7 @@ def get_standard_logging_object_payload(
logging_obj: Logging,
status: StandardLoggingPayloadStatus,
error_str: Optional[str] = None,
+ original_exception: Optional[Exception] = None,
) -> Optional[StandardLoggingPayload]:
try:
if kwargs is None:
@@ -2670,6 +2673,19 @@ def get_standard_logging_object_payload(
else:
response_obj = {}
+ if original_exception is not None and hidden_params is None:
+ response_headers = _get_response_headers(original_exception)
+ if response_headers is not None:
+ hidden_params = dict(
+ StandardLoggingHiddenParams(
+ additional_headers=dict(response_headers),
+ model_id=None,
+ cache_key=None,
+ api_base=None,
+ response_cost=None,
+ )
+ )
+
# standardize this function to be used across, s3, dynamoDB, langfuse logging
litellm_params = kwargs.get("litellm_params", {})
proxy_server_request = litellm_params.get("proxy_server_request") or {}
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index c1a5e6b67..962c629d7 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -5075,6 +5075,116 @@
"supports_function_calling": true,
"supports_tool_choice": false
},
+ "meta.llama3-2-1b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.0000001,
+ "output_cost_per_token": 0.0000001,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "us.meta.llama3-2-1b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.0000001,
+ "output_cost_per_token": 0.0000001,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "eu.meta.llama3-2-1b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.00000013,
+ "output_cost_per_token": 0.00000013,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "meta.llama3-2-3b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.00000015,
+ "output_cost_per_token": 0.00000015,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "us.meta.llama3-2-3b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.00000015,
+ "output_cost_per_token": 0.00000015,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "eu.meta.llama3-2-3b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.00000019,
+ "output_cost_per_token": 0.00000019,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "meta.llama3-2-11b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.00000035,
+ "output_cost_per_token": 0.00000035,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "us.meta.llama3-2-11b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.00000035,
+ "output_cost_per_token": 0.00000035,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "meta.llama3-2-90b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.000002,
+ "output_cost_per_token": 0.000002,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "us.meta.llama3-2-90b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.000002,
+ "output_cost_per_token": 0.000002,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
"512-x-512/50-steps/stability.stable-diffusion-xl-v0": {
"max_tokens": 77,
"max_input_tokens": 77,
diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html
deleted file mode 100644
index 3387c26ce..000000000
--- a/litellm/proxy/_experimental/out/404.html
+++ /dev/null
@@ -1 +0,0 @@
-
404: This page could not be found.LiteLLM Dashboard404
This page could not be found.
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub.html
deleted file mode 100644
index bc65f3d70..000000000
--- a/litellm/proxy/_experimental/out/model_hub.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html
deleted file mode 100644
index 9ee6afdcb..000000000
--- a/litellm/proxy/_experimental/out/onboarding.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
index 510bec43e..0e2def3cb 100644
--- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -52,7 +52,7 @@ def get_response_body(response: httpx.Response):
return response.text
-async def set_env_variables_in_header(custom_headers: dict):
+async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]:
"""
checks if any headers on config.yaml are defined as os.environ/COHERE_API_KEY etc
@@ -62,6 +62,8 @@ async def set_env_variables_in_header(custom_headers: dict):
{"Authorization": "bearer os.environ/COHERE_API_KEY"}
"""
+ if custom_headers is None:
+ return None
headers = {}
for key, value in custom_headers.items():
# langfuse Api requires base64 encoded headers - it's simpleer to just ask litellm users to set their langfuse public and secret keys
diff --git a/litellm/router.py b/litellm/router.py
index 537c14ddc..23880025e 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -32,6 +32,8 @@ from openai import AsyncOpenAI
from typing_extensions import overload
import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.exception_mapping_utils
from litellm import get_secret_str
from litellm._logging import verbose_router_logger
from litellm.assistants.main import AssistantDeleted
@@ -3661,9 +3663,10 @@ class Router:
kwargs.get("litellm_params", {}).get("metadata", None)
_model_info = kwargs.get("litellm_params", {}).get("model_info", {})
- exception_headers = litellm.utils._get_litellm_response_headers(
+ exception_headers = litellm.litellm_core_utils.exception_mapping_utils._get_response_headers(
original_exception=exception
)
+
_time_to_cooldown = kwargs.get("litellm_params", {}).get(
"cooldown_time", self.cooldown_time
)
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index c3118b453..2a36dd84d 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1418,3 +1418,17 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
# GCS dynamic params
gcs_bucket_name: Optional[str]
gcs_path_service_account: Optional[str]
+
+
+CachingSupportedCallTypes = Literal[
+ "completion",
+ "acompletion",
+ "embedding",
+ "aembedding",
+ "atranscription",
+ "transcription",
+ "atext_completion",
+ "text_completion",
+ "arerank",
+ "rerank",
+]
diff --git a/litellm/utils.py b/litellm/utils.py
index 9efde1be7..9524838d3 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -60,7 +60,6 @@ from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.exception_mapping_utils import (
- _get_litellm_response_headers,
_get_response_headers,
exception_type,
get_error_message,
@@ -82,6 +81,7 @@ from litellm.types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
+from litellm.types.rerank import RerankResponse
from litellm.types.utils import FileTypes # type: ignore
from litellm.types.utils import (
OPENAI_RESPONSE_HEADERS,
@@ -720,6 +720,7 @@ def client(original_function):
or kwargs.get("atext_completion", False) is True
or kwargs.get("atranscription", False) is True
or kwargs.get("arerank", False) is True
+ or kwargs.get("_arealtime", False) is True
):
# [OPTIONAL] CHECK MAX RETRIES / REQUEST
if litellm.num_retries_per_request is not None:
@@ -819,6 +820,8 @@ def client(original_function):
and kwargs.get("acompletion", False) is not True
and kwargs.get("aimg_generation", False) is not True
and kwargs.get("atranscription", False) is not True
+ and kwargs.get("arerank", False) is not True
+ and kwargs.get("_arealtime", False) is not True
): # allow users to control returning cached responses from the completion function
# checking cache
print_verbose("INSIDE CHECKING CACHE")
@@ -835,7 +838,6 @@ def client(original_function):
)
cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result is not None:
- print_verbose("Cache Hit!")
if "detail" in cached_result:
# implies an error occurred
pass
@@ -867,7 +869,13 @@ def client(original_function):
response_object=cached_result,
response_type="embedding",
)
-
+ elif call_type == CallTypes.rerank.value and isinstance(
+ cached_result, dict
+ ):
+ cached_result = convert_to_model_response_object(
+ response_object=cached_result,
+ response_type="rerank",
+ )
# LOG SUCCESS
cache_hit = True
end_time = datetime.datetime.now()
@@ -916,6 +924,12 @@ def client(original_function):
target=logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit),
).start()
+ cache_key = kwargs.get("preset_cache_key", None)
+ if (
+ isinstance(cached_result, BaseModel)
+ or isinstance(cached_result, CustomStreamWrapper)
+ ) and hasattr(cached_result, "_hidden_params"):
+ cached_result._hidden_params["cache_key"] = cache_key # type: ignore
return cached_result
else:
print_verbose(
@@ -991,8 +1005,7 @@ def client(original_function):
if (
litellm.cache is not None
and litellm.cache.supported_call_types is not None
- and str(original_function.__name__)
- in litellm.cache.supported_call_types
+ and call_type in litellm.cache.supported_call_types
) and (kwargs.get("cache", {}).get("no-store", False) is not True):
litellm.cache.add_cache(result, *args, **kwargs)
@@ -1257,6 +1270,14 @@ def client(original_function):
model_response_object=EmbeddingResponse(),
response_type="embedding",
)
+ elif call_type == CallTypes.arerank.value and isinstance(
+ cached_result, dict
+ ):
+ cached_result = convert_to_model_response_object(
+ response_object=cached_result,
+ model_response_object=None,
+ response_type="rerank",
+ )
elif call_type == CallTypes.atranscription.value and isinstance(
cached_result, dict
):
@@ -1460,6 +1481,7 @@ def client(original_function):
isinstance(result, litellm.ModelResponse)
or isinstance(result, litellm.EmbeddingResponse)
or isinstance(result, TranscriptionResponse)
+ or isinstance(result, RerankResponse)
):
if (
isinstance(result, EmbeddingResponse)
@@ -5880,10 +5902,16 @@ def convert_to_streaming_response(response_object: Optional[dict] = None):
def convert_to_model_response_object(
response_object: Optional[dict] = None,
model_response_object: Optional[
- Union[ModelResponse, EmbeddingResponse, ImageResponse, TranscriptionResponse]
+ Union[
+ ModelResponse,
+ EmbeddingResponse,
+ ImageResponse,
+ TranscriptionResponse,
+ RerankResponse,
+ ]
] = None,
response_type: Literal[
- "completion", "embedding", "image_generation", "audio_transcription"
+ "completion", "embedding", "image_generation", "audio_transcription", "rerank"
] = "completion",
stream=False,
start_time=None,
@@ -6133,6 +6161,27 @@ def convert_to_model_response_object(
if _response_headers is not None:
model_response_object._response_headers = _response_headers
+ return model_response_object
+ elif response_type == "rerank" and (
+ model_response_object is None
+ or isinstance(model_response_object, RerankResponse)
+ ):
+ if response_object is None:
+ raise Exception("Error in response object format")
+
+ if model_response_object is None:
+ model_response_object = RerankResponse(**response_object)
+ return model_response_object
+
+ if "id" in response_object:
+ model_response_object.id = response_object["id"]
+
+ if "meta" in response_object:
+ model_response_object.meta = response_object["meta"]
+
+ if "results" in response_object:
+ model_response_object.results = response_object["results"]
+
return model_response_object
except Exception:
raise Exception(
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index c1a5e6b67..962c629d7 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -5075,6 +5075,116 @@
"supports_function_calling": true,
"supports_tool_choice": false
},
+ "meta.llama3-2-1b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.0000001,
+ "output_cost_per_token": 0.0000001,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "us.meta.llama3-2-1b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.0000001,
+ "output_cost_per_token": 0.0000001,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "eu.meta.llama3-2-1b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.00000013,
+ "output_cost_per_token": 0.00000013,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "meta.llama3-2-3b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.00000015,
+ "output_cost_per_token": 0.00000015,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "us.meta.llama3-2-3b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.00000015,
+ "output_cost_per_token": 0.00000015,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "eu.meta.llama3-2-3b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.00000019,
+ "output_cost_per_token": 0.00000019,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "meta.llama3-2-11b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.00000035,
+ "output_cost_per_token": 0.00000035,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "us.meta.llama3-2-11b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.00000035,
+ "output_cost_per_token": 0.00000035,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "meta.llama3-2-90b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.000002,
+ "output_cost_per_token": 0.000002,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
+ "us.meta.llama3-2-90b-instruct-v1:0": {
+ "max_tokens": 128000,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.000002,
+ "output_cost_per_token": 0.000002,
+ "litellm_provider": "bedrock",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_tool_choice": false
+ },
"512-x-512/50-steps/stability.stable-diffusion-xl-v0": {
"max_tokens": 77,
"max_input_tokens": 77,
diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py
index a98b47603..8ba788a55 100644
--- a/tests/local_testing/test_caching.py
+++ b/tests/local_testing/test_caching.py
@@ -5,6 +5,7 @@ import traceback
import uuid
from dotenv import load_dotenv
+from test_rerank import assert_response_shape
load_dotenv()
import os
@@ -2234,3 +2235,56 @@ def test_logging_turn_off_message_logging_streaming():
mock_client.assert_called_once()
assert mock_client.call_args.args[0].choices[0].message.content == "hello"
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize(
+ "top_n_1, top_n_2, expect_cache_hit",
+ [
+ (3, 3, True),
+ (3, None, False),
+ ],
+)
+async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit):
+ litellm.set_verbose = True
+ litellm.cache = Cache(type="local")
+
+ if sync_mode is True:
+ for idx in range(2):
+ if idx == 0:
+ top_n = top_n_1
+ else:
+ top_n = top_n_2
+ response = litellm.rerank(
+ model="cohere/rerank-english-v3.0",
+ query="hello",
+ documents=["hello", "world"],
+ top_n=top_n,
+ )
+ else:
+ for idx in range(2):
+ if idx == 0:
+ top_n = top_n_1
+ else:
+ top_n = top_n_2
+ response = await litellm.arerank(
+ model="cohere/rerank-english-v3.0",
+ query="hello",
+ documents=["hello", "world"],
+ top_n=top_n,
+ )
+
+ await asyncio.sleep(1)
+
+ if expect_cache_hit is True:
+ assert "cache_key" in response._hidden_params
+ else:
+ assert "cache_key" not in response._hidden_params
+
+ print("re rank response: ", response)
+
+ assert response.id is not None
+ assert response.results is not None
+
+ assert_response_shape(response, custom_llm_provider="cohere")
diff --git a/tests/local_testing/test_custom_callback_input.py b/tests/local_testing/test_custom_callback_input.py
index 384b4b6fd..c079123e7 100644
--- a/tests/local_testing/test_custom_callback_input.py
+++ b/tests/local_testing/test_custom_callback_input.py
@@ -1385,9 +1385,9 @@ def test_logging_standard_payload_failure_call():
resp = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
- mock_response="litellm.RateLimitError",
+ api_key="my-bad-api-key",
)
- except litellm.RateLimitError:
+ except litellm.AuthenticationError:
pass
mock_client.assert_called_once()
@@ -1401,6 +1401,7 @@ def test_logging_standard_payload_failure_call():
standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
"kwargs"
]["standard_logging_object"]
+ assert "additional_headers" in standard_logging_object["hidden_params"]
@pytest.mark.parametrize("stream", [True, False])
diff --git a/tests/local_testing/test_custom_llm.py b/tests/local_testing/test_custom_llm.py
index 29daef481..f21b27c43 100644
--- a/tests/local_testing/test_custom_llm.py
+++ b/tests/local_testing/test_custom_llm.py
@@ -368,7 +368,7 @@ async def test_simple_image_generation_async():
@pytest.mark.asyncio
-async def test_image_generation_async_with_api_key_and_api_base():
+async def test_image_generation_async_additional_params():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
@@ -383,6 +383,7 @@ async def test_image_generation_async_with_api_key_and_api_base():
prompt="Hello world",
api_key="my-api-key",
api_base="my-api-base",
+ my_custom_param="my-custom-param",
)
print(resp)
@@ -393,3 +394,6 @@ async def test_image_generation_async_with_api_key_and_api_base():
mock_client.call_args.kwargs["api_key"] == "my-api-key"
mock_client.call_args.kwargs["api_base"] == "my-api-base"
+    assert mock_client.call_args.kwargs["optional_params"] == {
+        "my_custom_param": "my-custom-param"
+    }
diff --git a/tests/local_testing/test_pass_through_endpoints.py b/tests/local_testing/test_pass_through_endpoints.py
index 28e6acda9..b3977e936 100644
--- a/tests/local_testing/test_pass_through_endpoints.py
+++ b/tests/local_testing/test_pass_through_endpoints.py
@@ -39,6 +39,36 @@ def client():
return TestClient(app)
+@pytest.mark.asyncio
+async def test_pass_through_endpoint_no_headers(client, monkeypatch):
+ # Mock the httpx.AsyncClient.request method
+ monkeypatch.setattr("httpx.AsyncClient.request", mock_request)
+ import litellm
+
+ # Define a pass-through endpoint
+ pass_through_endpoints = [
+ {
+ "path": "/test-endpoint",
+ "target": "https://api.example.com/v1/chat/completions",
+ }
+ ]
+
+ # Initialize the pass-through endpoint
+ await initialize_pass_through_endpoints(pass_through_endpoints)
+ general_settings: dict = (
+ getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
+ )
+ general_settings.update({"pass_through_endpoints": pass_through_endpoints})
+ setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
+
+ # Make a request to the pass-through endpoint
+ response = client.post("/test-endpoint", json={"prompt": "Hello, world!"})
+
+ # Assert the response
+ assert response.status_code == 200
+ assert response.json() == {"message": "Mocked response"}
+
+
@pytest.mark.asyncio
async def test_pass_through_endpoint(client, monkeypatch):
# Mock the httpx.AsyncClient.request method