From 2acb0c0675d8758273d47b71e2c0c065b43f501d Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Sat, 12 Oct 2024 11:48:34 -0700 Subject: [PATCH] Litellm Minor Fixes & Improvements (10/12/2024) (#6179) * build(model_prices_and_context_window.json): add bedrock llama3.2 pricing * build(model_prices_and_context_window.json): add bedrock cross region inference pricing * Revert "(perf) move s3 logging to Batch logging + async [94% faster perf under 100 RPS on 1 litellm instance] (#6165)" This reverts commit 2a5624af471284f174e084142504d950ede2567d. * add azure/gpt-4o-2024-05-13 (#6174) * LiteLLM Minor Fixes & Improvements (10/10/2024) (#6158) * refactor(vertex_ai_partner_models/anthropic): refactor anthropic to use partner model logic * fix(vertex_ai/): support passing custom api base to partner models Fixes https://github.com/BerriAI/litellm/issues/4317 * fix(proxy_server.py): Fix prometheus premium user check logic * docs(prometheus.md): update quick start docs * fix(custom_llm.py): support passing dynamic api key + api base * fix(realtime_api/main.py): Add request/response logging for realtime api endpoints Closes https://github.com/BerriAI/litellm/issues/6081 * feat(openai/realtime): add openai realtime api logging Closes https://github.com/BerriAI/litellm/issues/6081 * fix(realtime_streaming.py): fix linting errors * fix(realtime_streaming.py): fix linting errors * fix: fix linting errors * fix pattern match router * Add literalai in the sidebar observability category (#6163) * fix: add literalai in the sidebar * fix: typo * update (#6160) * Feat: Add Langtrace integration (#5341) * Feat: Add Langtrace integration * add langtrace service name * fix timestamps for traces * add tests * Discard Callback + use existing otel logger * cleanup * remove print statments * remove callback * add docs * docs * add logging docs * format logging * remove emoji and add litellm proxy example * format logging * format `logging.md` * add langtrace docs to logging.md * sync conflict * docs fix * (perf) move s3 logging to Batch logging + async [94% faster perf under 100 RPS on 1 litellm instance] (#6165) * fix move s3 to use customLogger * add basic s3 logging test * add s3 to custom logger compatible * use batch logger for s3 * s3 set flush interval and batch size * fix s3 logging * add notes on s3 logging * fix s3 logging * add basic s3 logging test * fix s3 type errors * add test for sync logging on s3 * fix: fix to debug log --------- Co-authored-by: Ishaan Jaff Co-authored-by: Willy Douhard Co-authored-by: yujonglee Co-authored-by: Ali Waleed * docs(custom_llm_server.md): update doc on passing custom params * fix(pass_through_endpoints.py): don't require headers Fixes https://github.com/BerriAI/litellm/issues/6128 * feat(utils.py): add support for caching rerank endpoints Closes https://github.com/BerriAI/litellm/issues/6144 * feat(litellm_logging.py'): add response headers for failed requests Closes https://github.com/BerriAI/litellm/issues/6159 --------- Co-authored-by: Ishaan Jaff Co-authored-by: Willy Douhard Co-authored-by: yujonglee Co-authored-by: Ali Waleed --- .../docs/providers/custom_llm_server.md | 99 ++++++++++++++++ litellm/caching.py | 70 ++++------- litellm/integrations/langfuse.py | 1 - .../exception_mapping_utils.py | 25 +--- litellm/litellm_core_utils/litellm_logging.py | 16 +++ ...odel_prices_and_context_window_backup.json | 110 ++++++++++++++++++ litellm/proxy/_experimental/out/404.html | 1 - .../proxy/_experimental/out/model_hub.html | 1 - 
.../proxy/_experimental/out/onboarding.html | 1 - .../pass_through_endpoints.py | 4 +- litellm/router.py | 5 +- litellm/types/utils.py | 14 +++ litellm/utils.py | 63 ++++++++-- model_prices_and_context_window.json | 110 ++++++++++++++++++ tests/local_testing/test_caching.py | 54 +++++++++ .../test_custom_callback_input.py | 5 +- tests/local_testing/test_custom_llm.py | 6 +- .../test_pass_through_endpoints.py | 30 +++++ 18 files changed, 533 insertions(+), 82 deletions(-) delete mode 100644 litellm/proxy/_experimental/out/404.html delete mode 100644 litellm/proxy/_experimental/out/model_hub.html delete mode 100644 litellm/proxy/_experimental/out/onboarding.html diff --git a/docs/my-website/docs/providers/custom_llm_server.md b/docs/my-website/docs/providers/custom_llm_server.md index 6d2015010..2adb6a67c 100644 --- a/docs/my-website/docs/providers/custom_llm_server.md +++ b/docs/my-website/docs/providers/custom_llm_server.md @@ -251,6 +251,105 @@ Expected Response } ``` +## Additional Parameters + +Additional parameters are passed inside `optional_params` key in the `completion` or `image_generation` function. + +Here's how to set this: + + + + +```python +import litellm +from litellm import CustomLLM, completion, get_llm_provider + + +class MyCustomLLM(CustomLLM): + def completion(self, *args, **kwargs) -> litellm.ModelResponse: + assert kwargs["optional_params"] == {"my_custom_param": "my-custom-param"} # 👈 CHECK HERE + return litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello world"}], + mock_response="Hi!", + ) # type: ignore + +my_custom_llm = MyCustomLLM() + +litellm.custom_provider_map = [ # 👈 KEY STEP - REGISTER HANDLER + {"provider": "my-custom-llm", "custom_handler": my_custom_llm} + ] + +resp = completion(model="my-custom-llm/my-model", my_custom_param="my-custom-param") +``` + + + + + +1. Setup your `custom_handler.py` file +```python +import litellm +from litellm import CustomLLM +from litellm.types.utils import ImageResponse, ImageObject + + +class MyCustomLLM(CustomLLM): + async def aimage_generation(self, model: str, prompt: str, model_response: ImageResponse, optional_params: dict, logging_obj: Any, timeout: Optional[Union[float, httpx.Timeout]] = None, client: Optional[AsyncHTTPHandler] = None,) -> ImageResponse: + assert optional_params == {"my_custom_param": "my-custom-param"} # 👈 CHECK HERE + return ImageResponse( + created=int(time.time()), + data=[ImageObject(url="https://example.com/image.png")], + ) + +my_custom_llm = MyCustomLLM() +``` + + +2. Add to `config.yaml` + +In the config below, we pass + +python_filename: `custom_handler.py` +custom_handler_instance_name: `my_custom_llm`. This is defined in Step 1 + +custom_handler: `custom_handler.my_custom_llm` + +```yaml +model_list: + - model_name: "test-model" + litellm_params: + model: "openai/text-embedding-ada-002" + - model_name: "my-custom-model" + litellm_params: + model: "my-custom-llm/my-model" + my_custom_param: "my-custom-param" # 👈 CUSTOM PARAM + +litellm_settings: + custom_provider_map: + - {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm} +``` + +```bash +litellm --config /path/to/config.yaml +``` + +3. Test it! 
+ +```bash +curl -X POST 'http://0.0.0.0:4000/v1/images/generations' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-d '{ + "model": "my-custom-model", + "prompt": "A cute baby sea otter", +}' +``` + + + + + ## Custom Handler Spec diff --git a/litellm/caching.py b/litellm/caching.py index 91d9e6996..c9767b624 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -20,13 +20,13 @@ from datetime import timedelta from enum import Enum from typing import Any, List, Literal, Optional, Tuple, Union -from openai._models import BaseModel as OpenAIObject +from pydantic import BaseModel import litellm from litellm._logging import verbose_logger from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs from litellm.types.services import ServiceLoggerPayload, ServiceTypes -from litellm.types.utils import all_litellm_params +from litellm.types.utils import CachingSupportedCallTypes, all_litellm_params def print_verbose(print_statement): @@ -2139,20 +2139,7 @@ class Cache: default_in_memory_ttl: Optional[float] = None, default_in_redis_ttl: Optional[float] = None, similarity_threshold: Optional[float] = None, - supported_call_types: Optional[ - List[ - Literal[ - "completion", - "acompletion", - "embedding", - "aembedding", - "atranscription", - "transcription", - "atext_completion", - "text_completion", - ] - ] - ] = [ + supported_call_types: Optional[List[CachingSupportedCallTypes]] = [ "completion", "acompletion", "embedding", @@ -2161,6 +2148,8 @@ class Cache: "transcription", "atext_completion", "text_completion", + "arerank", + "rerank", ], # s3 Bucket, boto3 configuration s3_bucket_name: Optional[str] = None, @@ -2353,9 +2342,20 @@ class Cache: "file", "language", ] + rerank_only_kwargs = [ + "top_n", + "rank_fields", + "return_documents", + "max_chunks_per_doc", + "documents", + "query", + ] # combined_kwargs - NEEDS to be ordered across get_cache_key(). 
Do not use a set() combined_kwargs = ( - completion_kwargs + embedding_only_kwargs + transcription_only_kwargs + completion_kwargs + + embedding_only_kwargs + + transcription_only_kwargs + + rerank_only_kwargs ) litellm_param_kwargs = all_litellm_params for param in kwargs: @@ -2557,7 +2557,7 @@ class Cache: else: cache_key = self.get_cache_key(*args, **kwargs) if cache_key is not None: - if isinstance(result, OpenAIObject): + if isinstance(result, BaseModel): result = result.model_dump_json() ## DEFAULT TTL ## @@ -2778,20 +2778,7 @@ def enable_cache( host: Optional[str] = None, port: Optional[str] = None, password: Optional[str] = None, - supported_call_types: Optional[ - List[ - Literal[ - "completion", - "acompletion", - "embedding", - "aembedding", - "atranscription", - "transcription", - "atext_completion", - "text_completion", - ] - ] - ] = [ + supported_call_types: Optional[List[CachingSupportedCallTypes]] = [ "completion", "acompletion", "embedding", @@ -2800,6 +2787,8 @@ def enable_cache( "transcription", "atext_completion", "text_completion", + "arerank", + "rerank", ], **kwargs, ): @@ -2847,20 +2836,7 @@ def update_cache( host: Optional[str] = None, port: Optional[str] = None, password: Optional[str] = None, - supported_call_types: Optional[ - List[ - Literal[ - "completion", - "acompletion", - "embedding", - "aembedding", - "atranscription", - "transcription", - "atext_completion", - "text_completion", - ] - ] - ] = [ + supported_call_types: Optional[List[CachingSupportedCallTypes]] = [ "completion", "acompletion", "embedding", @@ -2869,6 +2845,8 @@ def update_cache( "transcription", "atext_completion", "text_completion", + "arerank", + "rerank", ], **kwargs, ): diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py index 177ff735c..c79b43422 100644 --- a/litellm/integrations/langfuse.py +++ b/litellm/integrations/langfuse.py @@ -191,7 +191,6 @@ class LangFuseLogger: pass # end of processing langfuse ######################## - if ( level == "ERROR" and status_message is not None diff --git a/litellm/litellm_core_utils/exception_mapping_utils.py b/litellm/litellm_core_utils/exception_mapping_utils.py index 61cca6e07..2572e695f 100644 --- a/litellm/litellm_core_utils/exception_mapping_utils.py +++ b/litellm/litellm_core_utils/exception_mapping_utils.py @@ -67,25 +67,6 @@ def get_error_message(error_obj) -> Optional[str]: ####### EXCEPTION MAPPING ################ -def _get_litellm_response_headers( - original_exception: Exception, -) -> Optional[httpx.Headers]: - """ - Extract and return the response headers from a mapped exception, if present. - - Used for accurate retry logic. - """ - _response_headers: Optional[httpx.Headers] = None - try: - _response_headers = getattr( - original_exception, "litellm_response_headers", None - ) - except Exception: - return None - - return _response_headers - - def _get_response_headers(original_exception: Exception) -> Optional[httpx.Headers]: """ Extract and return the response headers from an exception, if present. 
@@ -96,8 +77,12 @@ def _get_response_headers(original_exception: Exception) -> Optional[httpx.Heade try: _response_headers = getattr(original_exception, "headers", None) error_response = getattr(original_exception, "response", None) - if _response_headers is None and error_response: + if not _response_headers and error_response: _response_headers = getattr(error_response, "headers", None) + if not _response_headers: + _response_headers = getattr( + original_exception, "litellm_response_headers", None + ) except Exception: return None diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index a641be019..d3f15e6bc 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -84,6 +84,7 @@ from ..integrations.s3 import S3Logger from ..integrations.supabase import Supabase from ..integrations.traceloop import TraceloopLogger from ..integrations.weights_biases import WeightsBiasesLogger +from .exception_mapping_utils import _get_response_headers try: from ..proxy.enterprise.enterprise_callbacks.generic_api_callback import ( @@ -1813,6 +1814,7 @@ class Logging: logging_obj=self, status="failure", error_str=str(exception), + original_exception=exception, ) ) return start_time, end_time @@ -2654,6 +2656,7 @@ def get_standard_logging_object_payload( logging_obj: Logging, status: StandardLoggingPayloadStatus, error_str: Optional[str] = None, + original_exception: Optional[Exception] = None, ) -> Optional[StandardLoggingPayload]: try: if kwargs is None: @@ -2670,6 +2673,19 @@ def get_standard_logging_object_payload( else: response_obj = {} + if original_exception is not None and hidden_params is None: + response_headers = _get_response_headers(original_exception) + if response_headers is not None: + hidden_params = dict( + StandardLoggingHiddenParams( + additional_headers=dict(response_headers), + model_id=None, + cache_key=None, + api_base=None, + response_cost=None, + ) + ) + # standardize this function to be used across, s3, dynamoDB, langfuse logging litellm_params = kwargs.get("litellm_params", {}) proxy_server_request = litellm_params.get("proxy_server_request") or {} diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index c1a5e6b67..962c629d7 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -5075,6 +5075,116 @@ "supports_function_calling": true, "supports_tool_choice": false }, + "meta.llama3-2-1b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.0000001, + "output_cost_per_token": 0.0000001, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "us.meta.llama3-2-1b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.0000001, + "output_cost_per_token": 0.0000001, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "eu.meta.llama3-2-1b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000013, + "output_cost_per_token": 0.00000013, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + 
"meta.llama3-2-3b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000015, + "output_cost_per_token": 0.00000015, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "us.meta.llama3-2-3b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000015, + "output_cost_per_token": 0.00000015, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "eu.meta.llama3-2-3b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000019, + "output_cost_per_token": 0.00000019, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "meta.llama3-2-11b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000035, + "output_cost_per_token": 0.00000035, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "us.meta.llama3-2-11b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000035, + "output_cost_per_token": 0.00000035, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "meta.llama3-2-90b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.000002, + "output_cost_per_token": 0.000002, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "us.meta.llama3-2-90b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.000002, + "output_cost_per_token": 0.000002, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, "512-x-512/50-steps/stability.stable-diffusion-xl-v0": { "max_tokens": 77, "max_input_tokens": 77, diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html deleted file mode 100644 index 3387c26ce..000000000 --- a/litellm/proxy/_experimental/out/404.html +++ /dev/null @@ -1 +0,0 @@ -404: This page could not be found.LiteLLM Dashboard
\ No newline at end of file diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub.html deleted file mode 100644 index bc65f3d70..000000000 --- a/litellm/proxy/_experimental/out/model_hub.html +++ /dev/null @@ -1 +0,0 @@ -LiteLLM Dashboard \ No newline at end of file diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html deleted file mode 100644 index 9ee6afdcb..000000000 --- a/litellm/proxy/_experimental/out/onboarding.html +++ /dev/null @@ -1 +0,0 @@ -LiteLLM Dashboard \ No newline at end of file diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 510bec43e..0e2def3cb 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -52,7 +52,7 @@ def get_response_body(response: httpx.Response): return response.text -async def set_env_variables_in_header(custom_headers: dict): +async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]: """ checks if any headers on config.yaml are defined as os.environ/COHERE_API_KEY etc @@ -62,6 +62,8 @@ async def set_env_variables_in_header(custom_headers: dict): {"Authorization": "bearer os.environ/COHERE_API_KEY"} """ + if custom_headers is None: + return None headers = {} for key, value in custom_headers.items(): # langfuse Api requires base64 encoded headers - it's simpleer to just ask litellm users to set their langfuse public and secret keys diff --git a/litellm/router.py b/litellm/router.py index 537c14ddc..23880025e 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -32,6 +32,8 @@ from openai import AsyncOpenAI from typing_extensions import overload import litellm +import litellm.litellm_core_utils +import litellm.litellm_core_utils.exception_mapping_utils from litellm import get_secret_str from litellm._logging import verbose_router_logger from litellm.assistants.main import AssistantDeleted @@ -3661,9 +3663,10 @@ class Router: kwargs.get("litellm_params", {}).get("metadata", None) _model_info = kwargs.get("litellm_params", {}).get("model_info", {}) - exception_headers = litellm.utils._get_litellm_response_headers( + exception_headers = litellm.litellm_core_utils.exception_mapping_utils._get_response_headers( original_exception=exception ) + _time_to_cooldown = kwargs.get("litellm_params", {}).get( "cooldown_time", self.cooldown_time ) diff --git a/litellm/types/utils.py b/litellm/types/utils.py index c3118b453..2a36dd84d 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1418,3 +1418,17 @@ class StandardCallbackDynamicParams(TypedDict, total=False): # GCS dynamic params gcs_bucket_name: Optional[str] gcs_path_service_account: Optional[str] + + +CachingSupportedCallTypes = Literal[ + "completion", + "acompletion", + "embedding", + "aembedding", + "atranscription", + "transcription", + "atext_completion", + "text_completion", + "arerank", + "rerank", +] diff --git a/litellm/utils.py b/litellm/utils.py index 9efde1be7..9524838d3 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -60,7 +60,6 @@ from litellm.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.exception_mapping_utils import ( - _get_litellm_response_headers, _get_response_headers, exception_type, get_error_message, 
@@ -82,6 +81,7 @@ from litellm.types.llms.openai import ( ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk, ) +from litellm.types.rerank import RerankResponse from litellm.types.utils import FileTypes # type: ignore from litellm.types.utils import ( OPENAI_RESPONSE_HEADERS, @@ -720,6 +720,7 @@ def client(original_function): or kwargs.get("atext_completion", False) is True or kwargs.get("atranscription", False) is True or kwargs.get("arerank", False) is True + or kwargs.get("_arealtime", False) is True ): # [OPTIONAL] CHECK MAX RETRIES / REQUEST if litellm.num_retries_per_request is not None: @@ -819,6 +820,8 @@ def client(original_function): and kwargs.get("acompletion", False) is not True and kwargs.get("aimg_generation", False) is not True and kwargs.get("atranscription", False) is not True + and kwargs.get("arerank", False) is not True + and kwargs.get("_arealtime", False) is not True ): # allow users to control returning cached responses from the completion function # checking cache print_verbose("INSIDE CHECKING CACHE") @@ -835,7 +838,6 @@ def client(original_function): ) cached_result = litellm.cache.get_cache(*args, **kwargs) if cached_result is not None: - print_verbose("Cache Hit!") if "detail" in cached_result: # implies an error occurred pass @@ -867,7 +869,13 @@ def client(original_function): response_object=cached_result, response_type="embedding", ) - + elif call_type == CallTypes.rerank.value and isinstance( + cached_result, dict + ): + cached_result = convert_to_model_response_object( + response_object=cached_result, + response_type="rerank", + ) # LOG SUCCESS cache_hit = True end_time = datetime.datetime.now() @@ -916,6 +924,12 @@ def client(original_function): target=logging_obj.success_handler, args=(cached_result, start_time, end_time, cache_hit), ).start() + cache_key = kwargs.get("preset_cache_key", None) + if ( + isinstance(cached_result, BaseModel) + or isinstance(cached_result, CustomStreamWrapper) + ) and hasattr(cached_result, "_hidden_params"): + cached_result._hidden_params["cache_key"] = cache_key # type: ignore return cached_result else: print_verbose( @@ -991,8 +1005,7 @@ def client(original_function): if ( litellm.cache is not None and litellm.cache.supported_call_types is not None - and str(original_function.__name__) - in litellm.cache.supported_call_types + and call_type in litellm.cache.supported_call_types ) and (kwargs.get("cache", {}).get("no-store", False) is not True): litellm.cache.add_cache(result, *args, **kwargs) @@ -1257,6 +1270,14 @@ def client(original_function): model_response_object=EmbeddingResponse(), response_type="embedding", ) + elif call_type == CallTypes.arerank.value and isinstance( + cached_result, dict + ): + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=None, + response_type="rerank", + ) elif call_type == CallTypes.atranscription.value and isinstance( cached_result, dict ): @@ -1460,6 +1481,7 @@ def client(original_function): isinstance(result, litellm.ModelResponse) or isinstance(result, litellm.EmbeddingResponse) or isinstance(result, TranscriptionResponse) + or isinstance(result, RerankResponse) ): if ( isinstance(result, EmbeddingResponse) @@ -5880,10 +5902,16 @@ def convert_to_streaming_response(response_object: Optional[dict] = None): def convert_to_model_response_object( response_object: Optional[dict] = None, model_response_object: Optional[ - Union[ModelResponse, EmbeddingResponse, ImageResponse, TranscriptionResponse] + Union[ + ModelResponse, + 
EmbeddingResponse, + ImageResponse, + TranscriptionResponse, + RerankResponse, + ] ] = None, response_type: Literal[ - "completion", "embedding", "image_generation", "audio_transcription" + "completion", "embedding", "image_generation", "audio_transcription", "rerank" ] = "completion", stream=False, start_time=None, @@ -6133,6 +6161,27 @@ def convert_to_model_response_object( if _response_headers is not None: model_response_object._response_headers = _response_headers + return model_response_object + elif response_type == "rerank" and ( + model_response_object is None + or isinstance(model_response_object, RerankResponse) + ): + if response_object is None: + raise Exception("Error in response object format") + + if model_response_object is None: + model_response_object = RerankResponse(**response_object) + return model_response_object + + if "id" in response_object: + model_response_object.id = response_object["id"] + + if "meta" in response_object: + model_response_object.meta = response_object["meta"] + + if "results" in response_object: + model_response_object.results = response_object["results"] + return model_response_object except Exception: raise Exception( diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index c1a5e6b67..962c629d7 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -5075,6 +5075,116 @@ "supports_function_calling": true, "supports_tool_choice": false }, + "meta.llama3-2-1b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.0000001, + "output_cost_per_token": 0.0000001, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "us.meta.llama3-2-1b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.0000001, + "output_cost_per_token": 0.0000001, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "eu.meta.llama3-2-1b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000013, + "output_cost_per_token": 0.00000013, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "meta.llama3-2-3b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000015, + "output_cost_per_token": 0.00000015, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "us.meta.llama3-2-3b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000015, + "output_cost_per_token": 0.00000015, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "eu.meta.llama3-2-3b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000019, + "output_cost_per_token": 0.00000019, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "meta.llama3-2-11b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + 
"input_cost_per_token": 0.00000035, + "output_cost_per_token": 0.00000035, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "us.meta.llama3-2-11b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000035, + "output_cost_per_token": 0.00000035, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "meta.llama3-2-90b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.000002, + "output_cost_per_token": 0.000002, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, + "us.meta.llama3-2-90b-instruct-v1:0": { + "max_tokens": 128000, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.000002, + "output_cost_per_token": 0.000002, + "litellm_provider": "bedrock", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": false + }, "512-x-512/50-steps/stability.stable-diffusion-xl-v0": { "max_tokens": 77, "max_input_tokens": 77, diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py index a98b47603..8ba788a55 100644 --- a/tests/local_testing/test_caching.py +++ b/tests/local_testing/test_caching.py @@ -5,6 +5,7 @@ import traceback import uuid from dotenv import load_dotenv +from test_rerank import assert_response_shape load_dotenv() import os @@ -2234,3 +2235,56 @@ def test_logging_turn_off_message_logging_streaming(): mock_client.assert_called_once() assert mock_client.call_args.args[0].choices[0].message.content == "hello" + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.parametrize( + "top_n_1, top_n_2, expect_cache_hit", + [ + (3, 3, True), + (3, None, False), + ], +) +async def test_basic_rerank_caching(sync_mode, top_n_1, top_n_2, expect_cache_hit): + litellm.set_verbose = True + litellm.cache = Cache(type="local") + + if sync_mode is True: + for idx in range(2): + if idx == 0: + top_n = top_n_1 + else: + top_n = top_n_2 + response = litellm.rerank( + model="cohere/rerank-english-v3.0", + query="hello", + documents=["hello", "world"], + top_n=top_n, + ) + else: + for idx in range(2): + if idx == 0: + top_n = top_n_1 + else: + top_n = top_n_2 + response = await litellm.arerank( + model="cohere/rerank-english-v3.0", + query="hello", + documents=["hello", "world"], + top_n=top_n, + ) + + await asyncio.sleep(1) + + if expect_cache_hit is True: + assert "cache_key" in response._hidden_params + else: + assert "cache_key" not in response._hidden_params + + print("re rank response: ", response) + + assert response.id is not None + assert response.results is not None + + assert_response_shape(response, custom_llm_provider="cohere") diff --git a/tests/local_testing/test_custom_callback_input.py b/tests/local_testing/test_custom_callback_input.py index 384b4b6fd..c079123e7 100644 --- a/tests/local_testing/test_custom_callback_input.py +++ b/tests/local_testing/test_custom_callback_input.py @@ -1385,9 +1385,9 @@ def test_logging_standard_payload_failure_call(): resp = litellm.completion( model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}], - mock_response="litellm.RateLimitError", + api_key="my-bad-api-key", ) - except litellm.RateLimitError: + except 
litellm.AuthenticationError: pass mock_client.assert_called_once() @@ -1401,6 +1401,7 @@ def test_logging_standard_payload_failure_call(): standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[ "kwargs" ]["standard_logging_object"] + assert "additional_headers" in standard_logging_object["hidden_params"] @pytest.mark.parametrize("stream", [True, False]) diff --git a/tests/local_testing/test_custom_llm.py b/tests/local_testing/test_custom_llm.py index 29daef481..f21b27c43 100644 --- a/tests/local_testing/test_custom_llm.py +++ b/tests/local_testing/test_custom_llm.py @@ -368,7 +368,7 @@ async def test_simple_image_generation_async(): @pytest.mark.asyncio -async def test_image_generation_async_with_api_key_and_api_base(): +async def test_image_generation_async_additional_params(): my_custom_llm = MyCustomLLM() litellm.custom_provider_map = [ {"provider": "custom_llm", "custom_handler": my_custom_llm} @@ -383,6 +383,7 @@ async def test_image_generation_async_with_api_key_and_api_base(): prompt="Hello world", api_key="my-api-key", api_base="my-api-base", + my_custom_param="my-custom-param", ) print(resp) @@ -393,3 +394,6 @@ async def test_image_generation_async_with_api_key_and_api_base(): mock_client.call_args.kwargs["api_key"] == "my-api-key" mock_client.call_args.kwargs["api_base"] == "my-api-base" + mock_client.call_args.kwargs["optional_params"] == { + "my_custom_param": "my-custom-param" + } diff --git a/tests/local_testing/test_pass_through_endpoints.py b/tests/local_testing/test_pass_through_endpoints.py index 28e6acda9..b3977e936 100644 --- a/tests/local_testing/test_pass_through_endpoints.py +++ b/tests/local_testing/test_pass_through_endpoints.py @@ -39,6 +39,36 @@ def client(): return TestClient(app) +@pytest.mark.asyncio +async def test_pass_through_endpoint_no_headers(client, monkeypatch): + # Mock the httpx.AsyncClient.request method + monkeypatch.setattr("httpx.AsyncClient.request", mock_request) + import litellm + + # Define a pass-through endpoint + pass_through_endpoints = [ + { + "path": "/test-endpoint", + "target": "https://api.example.com/v1/chat/completions", + } + ] + + # Initialize the pass-through endpoint + await initialize_pass_through_endpoints(pass_through_endpoints) + general_settings: dict = ( + getattr(litellm.proxy.proxy_server, "general_settings", {}) or {} + ) + general_settings.update({"pass_through_endpoints": pass_through_endpoints}) + setattr(litellm.proxy.proxy_server, "general_settings", general_settings) + + # Make a request to the pass-through endpoint + response = client.post("/test-endpoint", json={"prompt": "Hello, world!"}) + + # Assert the response + assert response.status_code == 200 + assert response.json() == {"message": "Mocked response"} + + @pytest.mark.asyncio async def test_pass_through_endpoint(client, monkeypatch): # Mock the httpx.AsyncClient.request method