From 4d1b4beb3df41c7da0ace49ac83cbbde57edf452 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 14 Oct 2024 16:34:01 +0530 Subject: [PATCH] (refactor) caching use LLMCachingHandler for async_get_cache and set_cache (#6208) * use folder for caching * fix importing caching * fix clickhouse pyright * fix linting * fix correctly pass kwargs and args * fix test case for embedding * fix linting * fix embedding caching logic * fix refactor handle utils.py * fix test_embedding_caching_azure_individual_items_reordered --- .../test_loadtest_openai_client.py | 2 +- .../test_loadtest_router_withs3_cache.py | 2 +- docs/my-website/docs/caching/all_caches.md | 18 +- docs/my-website/docs/caching/caching_api.md | 6 +- docs/my-website/docs/caching/local_caching.md | 8 +- docs/my-website/docs/proxy/caching.md | 4 +- docs/my-website/docs/proxy/configs.md | 2 +- .../docs/proxy/guardrails/custom_guardrail.md | 2 +- .../generic_api_callback.py | 2 +- enterprise/enterprise_hooks/aporia_ai.py | 2 +- .../enterprise_hooks/banned_keywords.py | 2 +- .../enterprise_hooks/blocked_user_list.py | 2 +- .../google_text_moderation.py | 2 +- enterprise/enterprise_hooks/llama_guard.py | 2 +- enterprise/enterprise_hooks/llm_guard.py | 2 +- .../enterprise_hooks/openai_moderation.py | 2 +- .../enterprise_hooks/secret_detection.py | 2 +- litellm/__init__.py | 2 +- litellm/batch_completion/main.py | 11 +- litellm/{ => caching}/caching.py | 10 +- litellm/caching/caching_handler.py | 440 ++++++++++++++++++ .../deprecated_litellm_server/server_utils.py | 2 +- .../SlackAlerting/slack_alerting.py | 2 +- litellm/integrations/clickhouse.py | 68 ++- litellm/integrations/custom_logger.py | 2 +- litellm/litellm_core_utils/litellm_logging.py | 2 +- litellm/llms/AzureOpenAI/azure.py | 2 +- litellm/llms/base_aws_llm.py | 2 +- litellm/llms/bedrock/chat/invoke_handler.py | 2 +- .../llms/prompt_templates/image_handling.py | 2 +- .../vertex_ai_context_caching.py | 2 +- litellm/main.py | 2 +- litellm/proxy/auth/auth_checks.py | 2 +- litellm/proxy/auth/handle_jwt.py | 2 +- litellm/proxy/caching_routes.py | 2 +- .../example_config_yaml/custom_guardrail.py | 2 +- .../guardrails/guardrail_hooks/aporia_ai.py | 2 +- .../guardrail_hooks/bedrock_guardrails.py | 2 +- .../guardrail_hooks/custom_guardrail.py | 2 +- .../guardrails/guardrail_hooks/presidio.py | 2 +- .../health_endpoints/_health_endpoints.py | 2 +- litellm/proxy/hooks/azure_content_safety.py | 2 +- litellm/proxy/hooks/batch_redis_get.py | 2 +- litellm/proxy/hooks/cache_control_check.py | 2 +- litellm/proxy/hooks/dynamic_rate_limiter.py | 2 +- litellm/proxy/hooks/max_budget_limiter.py | 2 +- .../proxy/hooks/parallel_request_limiter.py | 2 +- litellm/proxy/hooks/presidio_pii_masking.py | 2 +- .../proxy/hooks/prompt_injection_detection.py | 2 +- litellm/proxy/proxy_server.py | 4 +- litellm/proxy/utils.py | 2 +- litellm/router.py | 2 +- litellm/router_strategy/least_busy.py | 2 +- litellm/router_strategy/lowest_cost.py | 2 +- litellm/router_strategy/lowest_latency.py | 2 +- litellm/router_strategy/lowest_tpm_rpm.py | 2 +- litellm/router_strategy/lowest_tpm_rpm_v2.py | 2 +- litellm/router_utils/cooldown_cache.py | 2 +- litellm/scheduler.py | 10 +- .../secret_managers/google_secret_manager.py | 2 +- litellm/secret_managers/main.py | 2 +- litellm/utils.py | 397 +++------------- tests/local_testing/test_add_update_models.py | 2 +- tests/local_testing/test_alerting.py | 2 +- tests/local_testing/test_auth_checks.py | 2 +- .../test_azure_content_safety.py | 2 +- 
 .../local_testing/test_banned_keyword_list.py | 2 +-
 tests/local_testing/test_blocked_user_list.py | 4 +-
 tests/local_testing/test_caching.py | 23 +-
 tests/local_testing/test_caching_ssl.py | 2 +-
 tests/local_testing/test_datadog.py | 2 +-
 tests/local_testing/test_jwt.py | 2 +-
 .../local_testing/test_key_generate_prisma.py | 8 +-
 .../test_lakera_ai_prompt_injection.py | 2 +-
 .../local_testing/test_least_busy_routing.py | 2 +-
 tests/local_testing/test_llm_guard.py | 2 +-
 .../local_testing/test_load_test_router_s3.py | 2 +-
 .../local_testing/test_lowest_cost_routing.py | 2 +-
 .../test_lowest_latency_routing.py | 2 +-
 .../local_testing/test_max_tpm_rpm_limiter.py | 2 +-
 .../test_openai_moderations_hook.py | 2 +-
 .../test_parallel_request_limiter.py | 2 +-
 tests/local_testing/test_presidio_masking.py | 2 +-
 .../local_testing/test_prometheus_service.py | 2 +-
 .../test_prompt_injection_detection.py | 2 +-
 .../test_proxy_reject_logging.py | 2 +-
 tests/local_testing/test_proxy_server.py | 4 +-
 .../local_testing/test_secret_detect_hook.py | 2 +-
 tests/local_testing/test_streaming.py | 2 +-
 .../local_testing/test_tpm_rpm_routing_v2.py | 2 +-
 tests/local_testing/test_update_spend.py | 4 +-
 tests/local_testing/test_whisper.py | 2 +-
 .../test_key_management.py | 2 +-
 .../test_role_based_access.py | 2 +-
 .../proxy_admin_ui_tests/test_sso_sign_in.py | 2 +-
 .../test_usage_endpoints.py | 2 +-
 96 files changed, 690 insertions(+), 489 deletions(-)
 rename litellm/{ => caching}/caching.py (99%)
 create mode 100644 litellm/caching/caching_handler.py

diff --git a/cookbook/litellm_router_load_test/test_loadtest_openai_client.py b/cookbook/litellm_router_load_test/test_loadtest_openai_client.py
index d11249cd2..63a0abd68 100644
--- a/cookbook/litellm_router_load_test/test_loadtest_openai_client.py
+++ b/cookbook/litellm_router_load_test/test_loadtest_openai_client.py
@@ -10,7 +10,7 @@ sys.path.insert(
 import asyncio
 from litellm import Router, Timeout
 import time
-from litellm.caching import Cache
+from litellm.caching.caching import Cache
 import litellm
 import openai

diff --git a/cookbook/litellm_router_load_test/test_loadtest_router_withs3_cache.py b/cookbook/litellm_router_load_test/test_loadtest_router_withs3_cache.py
index 5e6dd218c..4df8b7f5e 100644
--- a/cookbook/litellm_router_load_test/test_loadtest_router_withs3_cache.py
+++ b/cookbook/litellm_router_load_test/test_loadtest_router_withs3_cache.py
@@ -10,7 +10,7 @@ sys.path.insert(
 import asyncio
 from litellm import Router, Timeout
 import time
-from litellm.caching import Cache
+from litellm.caching.caching import Cache
 import litellm

 litellm.cache = Cache(

diff --git a/docs/my-website/docs/caching/all_caches.md b/docs/my-website/docs/caching/all_caches.md
index d6ccb98a2..dc1951cc7 100644
--- a/docs/my-website/docs/caching/all_caches.md
+++ b/docs/my-website/docs/caching/all_caches.md
@@ -3,7 +3,7 @@ import TabItem from '@theme/TabItem';

 # Caching - In-Memory, Redis, s3, Redis Semantic Cache, Disk

-[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/caching.py)
+[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/caching/caching.py)

 :::info

@@ -31,7 +31,7 @@ For the hosted version you can setup your own Redis DB here: https://app.redisla
 ```python
 import litellm
 from litellm import completion
-from litellm.caching import Cache
+from litellm.caching.caching import Cache

 litellm.cache = Cache(type="redis", host=<host>, port=<port>, password=<password>)
@@ -68,7 +68,7 @@ AWS_SECRET_ACCESS_KEY = "WOl*****"
 ```python
 import litellm
 from litellm import completion
-from litellm.caching import Cache
+from litellm.caching.caching import Cache

 # pass s3-bucket name
 litellm.cache = Cache(type="s3", s3_bucket_name="cache-bucket-litellm", s3_region_name="us-west-2")
@@ -101,7 +101,7 @@ For the hosted version you can setup your own Redis DB here: https://app.redisla
 ```python
 import litellm
 from litellm import completion
-from litellm.caching import Cache
+from litellm.caching.caching import Cache

 random_number = random.randint(
     1, 100000
@@ -155,7 +155,7 @@ To set up a Qdrant cluster locally follow: https://qdrant.tech/documentation/qui
 ```python
 import litellm
 from litellm import completion
-from litellm.caching import Cache
+from litellm.caching.caching import Cache

 random_number = random.randint(
     1, 100000
@@ -210,7 +210,7 @@ assert response1.id == response2.id
 ```python
 import litellm
 from litellm import completion
-from litellm.caching import Cache
+from litellm.caching.caching import Cache

 litellm.cache = Cache()

 # Make completion calls
@@ -246,7 +246,7 @@ Then you can use the disk cache as follows.
 ```python
 import litellm
 from litellm import completion
-from litellm.caching import Cache
+from litellm.caching.caching import Cache

 litellm.cache = Cache(type="disk")

 # Make completion calls
@@ -422,7 +422,7 @@ def custom_get_cache_key(*args, **kwargs):
 Set your function as litellm.cache.get_cache_key
 ```python
-from litellm.caching import Cache
+from litellm.caching.caching import Cache

 cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
@@ -434,7 +434,7 @@ litellm.cache = cache # set litellm.cache to your cache
 ## How to write custom add/get cache functions
 ### 1. Init Cache
 ```python
-from litellm.caching import Cache
+from litellm.caching.caching import Cache
 cache = Cache()
 ```

diff --git a/docs/my-website/docs/caching/caching_api.md b/docs/my-website/docs/caching/caching_api.md
index ff31c34ea..15ae7be0f 100644
--- a/docs/my-website/docs/caching/caching_api.md
+++ b/docs/my-website/docs/caching/caching_api.md
@@ -6,7 +6,7 @@ Use api.litellm.ai for caching `completion()` and `embedding()` responses
 ```python
 import litellm
 from litellm import completion
-from litellm.caching import Cache
+from litellm.caching.caching import Cache

 litellm.cache = Cache(type="hosted") # init cache to use api.litellm.ai

 # Make completion calls
@@ -31,7 +31,7 @@ response2 = completion(
 import time
 import litellm
 from litellm import completion, embedding
-from litellm.caching import Cache
+from litellm.caching.caching import Cache
 litellm.cache = Cache(type="hosted")

 start_time = time.time()
@@ -53,7 +53,7 @@ LiteLLM can cache your streamed responses for you
 import litellm
 import time
 from litellm import completion
-from litellm.caching import Cache
+from litellm.caching.caching import Cache

 litellm.cache = Cache(type="hosted")

diff --git a/docs/my-website/docs/caching/local_caching.md b/docs/my-website/docs/caching/local_caching.md
index 81c4edcb8..8b81438df 100644
--- a/docs/my-website/docs/caching/local_caching.md
+++ b/docs/my-website/docs/caching/local_caching.md
@@ -13,7 +13,7 @@ Keys in the cache are `model`, the following example will lead to a cache hit
 ```python
 import litellm
 from litellm import completion
-from litellm.caching import Cache
+from litellm.caching.caching import Cache
 litellm.cache = Cache()

 # Make completion calls
@@ -35,7 +35,7 @@ response2 = completion(
 Add custom key-value pairs to your cache.
 ```python
-from litellm.caching import Cache
+from litellm.caching.caching import Cache

 cache = Cache()

 cache.add_cache(cache_key="test-key", result="1234")
@@ -50,7 +50,7 @@ LiteLLM can cache your streamed responses for you
 ```python
 import litellm
 from litellm import completion
-from litellm.caching import Cache
+from litellm.caching.caching import Cache
 litellm.cache = Cache()

 # Make completion calls
@@ -77,7 +77,7 @@ Keys in the cache are `model`, the following example will lead to a cache hit
 import time
 import litellm
 from litellm import embedding
-from litellm.caching import Cache
+from litellm.caching.caching import Cache
 litellm.cache = Cache()

 start_time = time.time()

diff --git a/docs/my-website/docs/proxy/caching.md b/docs/my-website/docs/proxy/caching.md
index 533d1bd9f..9e3e27bf3 100644
--- a/docs/my-website/docs/proxy/caching.md
+++ b/docs/my-website/docs/proxy/caching.md
@@ -49,13 +49,13 @@ litellm_settings:
   cache: true
   cache_params:        # set cache params for redis
     type: redis
     namespace: "litellm_caching"
 ```

 and keys will be stored like:

 ```
 litellm_caching:
 ```

 #### Redis Cluster

diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md
index adb260541..1dcb95b6f 100644
--- a/docs/my-website/docs/proxy/configs.md
+++ b/docs/my-website/docs/proxy/configs.md
@@ -645,7 +645,7 @@ litellm_settings:
     host: "localhost" # The host address for the Redis cache. Required if type is "redis".
     port: 6379 # The port number for the Redis cache. Required if type is "redis".
     password: "your_password" # The password for the Redis cache. Required if type is "redis".
     namespace: "litellm_caching" # namespace for redis cache

     # Optional - Redis Cluster Settings
     redis_startup_nodes: [{"host": "127.0.0.1", "port": "7001"}]

diff --git a/docs/my-website/docs/proxy/guardrails/custom_guardrail.md b/docs/my-website/docs/proxy/guardrails/custom_guardrail.md
index 5277d46d4..ff3212273 100644
--- a/docs/my-website/docs/proxy/guardrails/custom_guardrail.md
+++ b/docs/my-website/docs/proxy/guardrails/custom_guardrail.md
@@ -25,7 +25,7 @@ from typing import Any, Dict, List, Literal, Optional, Union

 import litellm
 from litellm._logging import verbose_proxy_logger
-from litellm.caching import DualCache
+from litellm.caching.caching import DualCache
 from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata

diff --git a/enterprise/enterprise_callbacks/generic_api_callback.py b/enterprise/enterprise_callbacks/generic_api_callback.py
index 0b6487a86..eddaa0671 100644
--- a/enterprise/enterprise_callbacks/generic_api_callback.py
+++ b/enterprise/enterprise_callbacks/generic_api_callback.py
@@ -6,7 +6,7 @@
 import dotenv, os
 import requests

 from litellm.proxy._types import UserAPIKeyAuth
-from litellm.caching import DualCache
+from litellm.caching.caching import DualCache

 from typing import Literal, Union, Optional

diff --git a/enterprise/enterprise_hooks/aporia_ai.py b/enterprise/enterprise_hooks/aporia_ai.py
index 9da4b891b..27645257e 100644
--- a/enterprise/enterprise_hooks/aporia_ai.py
+++ b/enterprise/enterprise_hooks/aporia_ai.py
@@ -13,7 +13,7 @@ sys.path.insert(
 )  # Adds the parent directory to the system path
 from typing import Optional, Literal, Union, Any
 import litellm, traceback, sys, uuid
-from
litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.integrations.custom_guardrail import CustomGuardrail from fastapi import HTTPException diff --git a/enterprise/enterprise_hooks/banned_keywords.py b/enterprise/enterprise_hooks/banned_keywords.py index e282ee5ab..7a6306ed5 100644 --- a/enterprise/enterprise_hooks/banned_keywords.py +++ b/enterprise/enterprise_hooks/banned_keywords.py @@ -9,7 +9,7 @@ from typing import Optional, Literal import litellm -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.integrations.custom_logger import CustomLogger from litellm._logging import verbose_proxy_logger diff --git a/enterprise/enterprise_hooks/blocked_user_list.py b/enterprise/enterprise_hooks/blocked_user_list.py index 0bcdbce0c..f978d8756 100644 --- a/enterprise/enterprise_hooks/blocked_user_list.py +++ b/enterprise/enterprise_hooks/blocked_user_list.py @@ -10,7 +10,7 @@ from typing import Optional, Literal import litellm from litellm.proxy.utils import PrismaClient -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth, LiteLLM_EndUserTable from litellm.integrations.custom_logger import CustomLogger from litellm._logging import verbose_proxy_logger diff --git a/enterprise/enterprise_hooks/google_text_moderation.py b/enterprise/enterprise_hooks/google_text_moderation.py index 918e59f46..06d95ff87 100644 --- a/enterprise/enterprise_hooks/google_text_moderation.py +++ b/enterprise/enterprise_hooks/google_text_moderation.py @@ -9,7 +9,7 @@ from typing import Optional, Literal, Union import litellm, traceback, sys, uuid -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.integrations.custom_logger import CustomLogger from fastapi import HTTPException diff --git a/enterprise/enterprise_hooks/llama_guard.py b/enterprise/enterprise_hooks/llama_guard.py index e87bb45ca..5ee6f3b30 100644 --- a/enterprise/enterprise_hooks/llama_guard.py +++ b/enterprise/enterprise_hooks/llama_guard.py @@ -15,7 +15,7 @@ sys.path.insert( ) # Adds the parent directory to the system path from typing import Optional, Literal, Union import litellm, traceback, sys, uuid -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.integrations.custom_logger import CustomLogger from fastapi import HTTPException diff --git a/enterprise/enterprise_hooks/llm_guard.py b/enterprise/enterprise_hooks/llm_guard.py index b8c11ba0f..04ac66211 100644 --- a/enterprise/enterprise_hooks/llm_guard.py +++ b/enterprise/enterprise_hooks/llm_guard.py @@ -13,7 +13,7 @@ import traceback import sys import uuid import os -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.integrations.custom_logger import CustomLogger from fastapi import HTTPException diff --git a/enterprise/enterprise_hooks/openai_moderation.py b/enterprise/enterprise_hooks/openai_moderation.py index a6806ae8a..0b9efc25f 100644 --- a/enterprise/enterprise_hooks/openai_moderation.py +++ b/enterprise/enterprise_hooks/openai_moderation.py @@ -12,7 +12,7 @@ sys.path.insert( ) # Adds the parent directory to the system path from typing import Optional, Literal, 
Union import litellm, traceback, sys, uuid -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.integrations.custom_logger import CustomLogger from fastapi import HTTPException diff --git a/enterprise/enterprise_hooks/secret_detection.py b/enterprise/enterprise_hooks/secret_detection.py index 0574d3a05..414f3c4dd 100644 --- a/enterprise/enterprise_hooks/secret_detection.py +++ b/enterprise/enterprise_hooks/secret_detection.py @@ -12,7 +12,7 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path from typing import Optional -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm._logging import verbose_proxy_logger import tempfile diff --git a/litellm/__init__.py b/litellm/__init__.py index b1589917a..f6713646d 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -7,7 +7,7 @@ import threading import os from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.caching import Cache +from litellm.caching.caching import Cache from litellm._logging import ( set_verbose, _turn_on_debug, diff --git a/litellm/batch_completion/main.py b/litellm/batch_completion/main.py index e13f5fd2e..426ccfb15 100644 --- a/litellm/batch_completion/main.py +++ b/litellm/batch_completion/main.py @@ -2,7 +2,6 @@ from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait from typing import List, Optional import litellm -from litellm import completion from litellm._logging import print_verbose from litellm.utils import get_optional_params @@ -108,7 +107,7 @@ def batch_completion( if "kwargs" in kwargs_modified: original_kwargs = kwargs_modified.pop("kwargs") future = executor.submit( - completion, **kwargs_modified, **original_kwargs + litellm.completion, **kwargs_modified, **original_kwargs ) completions.append(future) @@ -156,7 +155,7 @@ def batch_completion_models(*args, **kwargs): with ThreadPoolExecutor(max_workers=len(models)) as executor: for model in models: futures[model] = executor.submit( - completion, *args, model=model, **kwargs + litellm.completion, *args, model=model, **kwargs ) for model, future in sorted( @@ -178,7 +177,9 @@ def batch_completion_models(*args, **kwargs): ): # don't override deployment values e.g. model name, api base, etc. 
deployment[key] = kwargs[key] kwargs = {**deployment, **nested_kwargs} - futures[deployment["model"]] = executor.submit(completion, **kwargs) + futures[deployment["model"]] = executor.submit( + litellm.completion, **kwargs + ) while futures: # wait for the first returned future @@ -246,7 +247,7 @@ def batch_completion_models_all_responses(*args, **kwargs): with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor: for idx, model in enumerate(models): - future = executor.submit(completion, *args, model=model, **kwargs) + future = executor.submit(litellm.completion, *args, model=model, **kwargs) if future.result() is not None: responses.append(future.result()) diff --git a/litellm/caching.py b/litellm/caching/caching.py similarity index 99% rename from litellm/caching.py rename to litellm/caching/caching.py index c9767b624..c16993625 100644 --- a/litellm/caching.py +++ b/litellm/caching/caching.py @@ -212,7 +212,7 @@ class RedisCache(BaseCache): from litellm._service_logger import ServiceLogging - from ._redis import get_redis_client, get_redis_connection_pool + from .._redis import get_redis_client, get_redis_connection_pool redis_kwargs = {} if host is not None: @@ -276,7 +276,7 @@ class RedisCache(BaseCache): ) def init_async_client(self): - from ._redis import get_redis_async_client + from .._redis import get_redis_async_client return get_redis_async_client( connection_pool=self.async_redis_conn_pool, **self.redis_kwargs @@ -302,7 +302,7 @@ class RedisCache(BaseCache): except Exception as e: # NON blocking - notify users Redis is throwing an exception print_verbose( - f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}" + f"litellm.caching.caching: set() - Got exception from REDIS : {str(e)}" ) def increment_cache( @@ -705,7 +705,7 @@ class RedisCache(BaseCache): except Exception as e: # NON blocking - notify users Redis is throwing an exception verbose_logger.error( - "LiteLLM Caching: get() - Got exception from REDIS: ", e + "litellm.caching.caching: get() - Got exception from REDIS: ", e ) def batch_get_cache(self, key_list) -> dict: @@ -781,7 +781,7 @@ class RedisCache(BaseCache): ) # NON blocking - notify users Redis is throwing an exception print_verbose( - f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}" + f"litellm.caching.caching: async get() - Got exception from REDIS: {str(e)}" ) async def async_batch_get_cache(self, key_list) -> dict: diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py new file mode 100644 index 000000000..58de0b42f --- /dev/null +++ b/litellm/caching/caching_handler.py @@ -0,0 +1,440 @@ +""" +This contains LLMCachingHandler + +This exposes two methods: + - async_get_cache + - async_set_cache + +This file is a wrapper around caching.py + +In each method it will call the appropriate method from caching.py +""" + +import asyncio +import datetime +import threading +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple + +from pydantic import BaseModel + +import litellm +from litellm._logging import print_verbose +from litellm.caching.caching import ( + Cache, + QdrantSemanticCache, + RedisCache, + RedisSemanticCache, + S3Cache, +) +from litellm.types.rerank import RerankResponse +from litellm.types.utils import ( + CallTypes, + Embedding, + EmbeddingResponse, + ModelResponse, + TextCompletionResponse, + TranscriptionResponse, +) + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +else: + 
LiteLLMLoggingObj = Any + + +class CachingHandlerResponse(BaseModel): + """ + This is the response object for the caching handler. We need to separate embedding cached responses and (completion / text_completion / transcription) cached responses + + For embeddings there can be a cache hit for some of the inputs in the list and a cache miss for others + """ + + cached_result: Optional[Any] = None + final_embedding_cached_response: Optional[EmbeddingResponse] = None + embedding_all_elements_cache_hit: bool = ( + False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call + ) + + +class LLMCachingHandler: + def __init__(self): + pass + + async def _async_get_cache( + self, + model: str, + original_function: Callable, + logging_obj: LiteLLMLoggingObj, + start_time: datetime.datetime, + call_type: str, + kwargs: Dict[str, Any], + args: Optional[Tuple[Any, ...]] = None, + ) -> CachingHandlerResponse: + """ + Internal method to get from the cache. + Handles different call types (embeddings, chat/completions, text_completion, transcription) + and accordingly returns the cached response + + Args: + model: str: + original_function: Callable: + logging_obj: LiteLLMLoggingObj: + start_time: datetime.datetime: + call_type: str: + kwargs: Dict[str, Any]: + args: Optional[Tuple[Any, ...]] = None: + + + Returns: + CachingHandlerResponse: + Raises: + None + """ + from litellm.utils import ( + CustomStreamWrapper, + convert_to_model_response_object, + convert_to_streaming_response_async, + ) + + args = args or () + + final_embedding_cached_response: Optional[EmbeddingResponse] = None + cached_result: Optional[Any] = None + if ( + (kwargs.get("caching", None) is None and litellm.cache is not None) + or kwargs.get("caching", False) is True + ) and ( + kwargs.get("cache", {}).get("no-cache", False) is not True + ): # allow users to control returning cached responses from the completion function + # checking cache + print_verbose("INSIDE CHECKING CACHE") + if ( + litellm.cache is not None + and litellm.cache.supported_call_types is not None + and str(original_function.__name__) + in litellm.cache.supported_call_types + ): + print_verbose("Checking Cache") + if call_type == CallTypes.aembedding.value and isinstance( + kwargs["input"], list + ): + tasks = [] + for idx, i in enumerate(kwargs["input"]): + preset_cache_key = litellm.cache.get_cache_key( + *args, **{**kwargs, "input": i} + ) + tasks.append( + litellm.cache.async_get_cache(cache_key=preset_cache_key) + ) + cached_result = await asyncio.gather(*tasks) + ## check if cached result is None ## + if cached_result is not None and isinstance(cached_result, list): + # set cached_result to None if all elements are None + if all(result is None for result in cached_result): + cached_result = None + elif isinstance(litellm.cache.cache, RedisSemanticCache) or isinstance( + litellm.cache.cache, RedisCache + ): + preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) + kwargs["preset_cache_key"] = ( + preset_cache_key # for streaming calls, we need to pass the preset_cache_key + ) + cached_result = await litellm.cache.async_get_cache(*args, **kwargs) + elif isinstance(litellm.cache.cache, QdrantSemanticCache): + preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) + kwargs["preset_cache_key"] = ( + preset_cache_key # for streaming calls, we need to pass the preset_cache_key + ) + cached_result = await 
litellm.cache.async_get_cache(*args, **kwargs) + else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync] + preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) + kwargs["preset_cache_key"] = ( + preset_cache_key # for streaming calls, we need to pass the preset_cache_key + ) + cached_result = litellm.cache.get_cache(*args, **kwargs) + if cached_result is not None and not isinstance(cached_result, list): + print_verbose("Cache Hit!") + cache_hit = True + end_time = datetime.datetime.now() + ( + model, + custom_llm_provider, + dynamic_api_key, + api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=kwargs.get("custom_llm_provider", None), + api_base=kwargs.get("api_base", None), + api_key=kwargs.get("api_key", None), + ) + print_verbose( + f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" + ) + logging_obj.update_environment_variables( + model=model, + user=kwargs.get("user", None), + optional_params={}, + litellm_params={ + "logger_fn": kwargs.get("logger_fn", None), + "acompletion": True, + "metadata": kwargs.get("metadata", {}), + "model_info": kwargs.get("model_info", {}), + "proxy_server_request": kwargs.get( + "proxy_server_request", None + ), + "preset_cache_key": kwargs.get("preset_cache_key", None), + "stream_response": kwargs.get("stream_response", {}), + "api_base": kwargs.get("api_base", ""), + }, + input=kwargs.get("messages", ""), + api_key=kwargs.get("api_key", None), + original_response=str(cached_result), + additional_args=None, + stream=kwargs.get("stream", False), + ) + call_type = original_function.__name__ + if call_type == CallTypes.acompletion.value and isinstance( + cached_result, dict + ): + if kwargs.get("stream", False) is True: + cached_result = convert_to_streaming_response_async( + response_object=cached_result, + ) + cached_result = CustomStreamWrapper( + completion_stream=cached_result, + model=model, + custom_llm_provider="cached_response", + logging_obj=logging_obj, + ) + else: + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=ModelResponse(), + ) + if call_type == CallTypes.atext_completion.value and isinstance( + cached_result, dict + ): + if kwargs.get("stream", False) is True: + cached_result = convert_to_streaming_response_async( + response_object=cached_result, + ) + cached_result = CustomStreamWrapper( + completion_stream=cached_result, + model=model, + custom_llm_provider="cached_response", + logging_obj=logging_obj, + ) + else: + cached_result = TextCompletionResponse(**cached_result) + elif call_type == CallTypes.aembedding.value and isinstance( + cached_result, dict + ): + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=EmbeddingResponse(), + response_type="embedding", + ) + elif call_type == CallTypes.arerank.value and isinstance( + cached_result, dict + ): + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=None, + response_type="rerank", + ) + elif call_type == CallTypes.atranscription.value and isinstance( + cached_result, dict + ): + hidden_params = { + "model": "whisper-1", + "custom_llm_provider": custom_llm_provider, + "cache_hit": True, + } + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=TranscriptionResponse(), + response_type="audio_transcription", + 
hidden_params=hidden_params, + ) + if kwargs.get("stream", False) is False: + # LOG SUCCESS + asyncio.create_task( + logging_obj.async_success_handler( + cached_result, start_time, end_time, cache_hit + ) + ) + threading.Thread( + target=logging_obj.success_handler, + args=(cached_result, start_time, end_time, cache_hit), + ).start() + cache_key = kwargs.get("preset_cache_key", None) + if ( + isinstance(cached_result, BaseModel) + or isinstance(cached_result, CustomStreamWrapper) + ) and hasattr(cached_result, "_hidden_params"): + cached_result._hidden_params["cache_key"] = cache_key # type: ignore + return CachingHandlerResponse(cached_result=cached_result) + elif ( + call_type == CallTypes.aembedding.value + and cached_result is not None + and isinstance(cached_result, list) + and litellm.cache is not None + and not isinstance( + litellm.cache.cache, S3Cache + ) # s3 doesn't support bulk writing. Exclude. + ): + remaining_list = [] + non_null_list = [] + for idx, cr in enumerate(cached_result): + if cr is None: + remaining_list.append(kwargs["input"][idx]) + else: + non_null_list.append((idx, cr)) + original_kwargs_input = kwargs["input"] + kwargs["input"] = remaining_list + if len(non_null_list) > 0: + print_verbose(f"EMBEDDING CACHE HIT! - {len(non_null_list)}") + final_embedding_cached_response = EmbeddingResponse( + model=kwargs.get("model"), + data=[None] * len(original_kwargs_input), + ) + final_embedding_cached_response._hidden_params["cache_hit"] = ( + True + ) + + for val in non_null_list: + idx, cr = val # (idx, cr) tuple + if cr is not None: + final_embedding_cached_response.data[idx] = Embedding( + embedding=cr["embedding"], + index=idx, + object="embedding", + ) + if len(remaining_list) == 0: + # LOG SUCCESS + cache_hit = True + end_time = datetime.datetime.now() + ( + model, + custom_llm_provider, + dynamic_api_key, + api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=kwargs.get("custom_llm_provider", None), + api_base=kwargs.get("api_base", None), + api_key=kwargs.get("api_key", None), + ) + print_verbose( + f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" + ) + logging_obj.update_environment_variables( + model=model, + user=kwargs.get("user", None), + optional_params={}, + litellm_params={ + "logger_fn": kwargs.get("logger_fn", None), + "acompletion": True, + "metadata": kwargs.get("metadata", {}), + "model_info": kwargs.get("model_info", {}), + "proxy_server_request": kwargs.get( + "proxy_server_request", None + ), + "preset_cache_key": kwargs.get( + "preset_cache_key", None + ), + "stream_response": kwargs.get("stream_response", {}), + "api_base": "", + }, + input=kwargs.get("messages", ""), + api_key=kwargs.get("api_key", None), + original_response=str(final_embedding_cached_response), + additional_args=None, + stream=kwargs.get("stream", False), + ) + asyncio.create_task( + logging_obj.async_success_handler( + final_embedding_cached_response, + start_time, + end_time, + cache_hit, + ) + ) + threading.Thread( + target=logging_obj.success_handler, + args=( + final_embedding_cached_response, + start_time, + end_time, + cache_hit, + ), + ).start() + return CachingHandlerResponse( + final_embedding_cached_response=final_embedding_cached_response, + embedding_all_elements_cache_hit=True, + ) + return CachingHandlerResponse( + cached_result=cached_result, + final_embedding_cached_response=final_embedding_cached_response, + ) + + async def _async_set_cache( + self, + result: Any, + 
original_function: Callable, + kwargs: Dict[str, Any], + args: Optional[Tuple[Any, ...]] = None, + ): + """ + Internal method to check the type of the result & cache used and adds the result to the cache accordingly + + Args: + result: Any: + original_function: Callable: + kwargs: Dict[str, Any]: + args: Optional[Tuple[Any, ...]] = None: + + Returns: + None + Raises: + None + """ + args = args or () + # [OPTIONAL] ADD TO CACHE + if ( + (litellm.cache is not None) + and litellm.cache.supported_call_types is not None + and (str(original_function.__name__) in litellm.cache.supported_call_types) + and (kwargs.get("cache", {}).get("no-store", False) is not True) + ): + if ( + isinstance(result, litellm.ModelResponse) + or isinstance(result, litellm.EmbeddingResponse) + or isinstance(result, TranscriptionResponse) + or isinstance(result, RerankResponse) + ): + if ( + isinstance(result, EmbeddingResponse) + and isinstance(kwargs["input"], list) + and litellm.cache is not None + and not isinstance( + litellm.cache.cache, S3Cache + ) # s3 doesn't support bulk writing. Exclude. + ): + asyncio.create_task( + litellm.cache.async_add_cache_pipeline(result, *args, **kwargs) + ) + elif isinstance(litellm.cache.cache, S3Cache): + threading.Thread( + target=litellm.cache.add_cache, + args=(result,) + args, + kwargs=kwargs, + ).start() + else: + asyncio.create_task( + litellm.cache.async_add_cache(result.json(), *args, **kwargs) + ) + else: + asyncio.create_task( + litellm.cache.async_add_cache(result, *args, **kwargs) + ) diff --git a/litellm/deprecated_litellm_server/server_utils.py b/litellm/deprecated_litellm_server/server_utils.py index 7f8584c46..ac28727fa 100644 --- a/litellm/deprecated_litellm_server/server_utils.py +++ b/litellm/deprecated_litellm_server/server_utils.py @@ -43,7 +43,7 @@ # ### REDIS # # if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0: # # print(f"redis host: {os.getenv('REDIS_HOST')}; redis port: {os.getenv('REDIS_PORT')}; password: {os.getenv('REDIS_PASSWORD')}") -# # from litellm.caching import Cache +# # from litellm.caching.caching import Cache # # litellm.cache = Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD")) # # print("\033[92mLiteLLM: Switched on Redis caching\033[0m") diff --git a/litellm/integrations/SlackAlerting/slack_alerting.py b/litellm/integrations/SlackAlerting/slack_alerting.py index b276d37d7..b39d8d2ee 100644 --- a/litellm/integrations/SlackAlerting/slack_alerting.py +++ b/litellm/integrations/SlackAlerting/slack_alerting.py @@ -21,7 +21,7 @@ import litellm.litellm_core_utils import litellm.litellm_core_utils.litellm_logging import litellm.types from litellm._logging import verbose_logger, verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_batch_logger import CustomBatchLogger from litellm.litellm_core_utils.exception_mapping_utils import ( _add_key_name_and_team_to_alert, diff --git a/litellm/integrations/clickhouse.py b/litellm/integrations/clickhouse.py index 5abcf5eec..e4f43463f 100644 --- a/litellm/integrations/clickhouse.py +++ b/litellm/integrations/clickhouse.py @@ -13,7 +13,7 @@ import requests import litellm from litellm._logging import verbose_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.types.utils import 
StandardLoggingPayload @@ -29,14 +29,30 @@ def create_client(): clickhouse_host = os.getenv("CLICKHOUSE_HOST") if clickhouse_host is not None: verbose_logger.debug("setting up clickhouse") + + port = os.getenv("CLICKHOUSE_PORT") if port is not None and isinstance(port, str): port = int(port) + host: Optional[str] = os.getenv("CLICKHOUSE_HOST") + if host is None: + raise ValueError("CLICKHOUSE_HOST is not set") + + username: Optional[str] = os.getenv("CLICKHOUSE_USERNAME") + if username is None: + raise ValueError("CLICKHOUSE_USERNAME is not set") + + password: Optional[str] = os.getenv("CLICKHOUSE_PASSWORD") + if password is None: + raise ValueError("CLICKHOUSE_PASSWORD is not set") + if port is None: + raise ValueError("CLICKHOUSE_PORT is not set") + client = clickhouse_connect.get_client( - host=os.getenv("CLICKHOUSE_HOST"), + host=host, port=port, - username=os.getenv("CLICKHOUSE_USERNAME"), - password=os.getenv("CLICKHOUSE_PASSWORD"), + username=username, + password=password, ) return client else: @@ -176,11 +192,29 @@ def _start_clickhouse(): if port is not None and isinstance(port, str): port = int(port) + port = os.getenv("CLICKHOUSE_PORT") + if port is not None and isinstance(port, str): + port = int(port) + + host: Optional[str] = os.getenv("CLICKHOUSE_HOST") + if host is None: + raise ValueError("CLICKHOUSE_HOST is not set") + + username: Optional[str] = os.getenv("CLICKHOUSE_USERNAME") + if username is None: + raise ValueError("CLICKHOUSE_USERNAME is not set") + + password: Optional[str] = os.getenv("CLICKHOUSE_PASSWORD") + if password is None: + raise ValueError("CLICKHOUSE_PASSWORD is not set") + if port is None: + raise ValueError("CLICKHOUSE_PORT is not set") + client = clickhouse_connect.get_client( - host=os.getenv("CLICKHOUSE_HOST"), + host=host, port=port, - username=os.getenv("CLICKHOUSE_USERNAME"), - password=os.getenv("CLICKHOUSE_PASSWORD"), + username=username, + password=password, ) # view all tables in DB response = client.query("SHOW TABLES") @@ -241,11 +275,25 @@ class ClickhouseLogger: if port is not None and isinstance(port, str): port = int(port) + host: Optional[str] = os.getenv("CLICKHOUSE_HOST") + if host is None: + raise ValueError("CLICKHOUSE_HOST is not set") + + username: Optional[str] = os.getenv("CLICKHOUSE_USERNAME") + if username is None: + raise ValueError("CLICKHOUSE_USERNAME is not set") + + password: Optional[str] = os.getenv("CLICKHOUSE_PASSWORD") + if password is None: + raise ValueError("CLICKHOUSE_PASSWORD is not set") + if port is None: + raise ValueError("CLICKHOUSE_PORT is not set") + client = clickhouse_connect.get_client( - host=os.getenv("CLICKHOUSE_HOST"), + host=host, port=port, - username=os.getenv("CLICKHOUSE_USERNAME"), - password=os.getenv("CLICKHOUSE_PASSWORD"), + username=username, + password=password, ) self.client = client diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index d330f4f17..1d23d2904 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -8,7 +8,7 @@ from typing import Any, Literal, Optional, Tuple, Union import dotenv from pydantic import BaseModel -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.types.llms.openai import ChatCompletionRequest from litellm.types.services import ServiceLoggerPayload diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index d3f15e6bc..4bffa94b3 
100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -23,7 +23,7 @@ from litellm import ( turn_off_message_logging, verbose_logger, ) -from litellm.caching import DualCache, InMemoryCache, S3Cache +from litellm.caching.caching import DualCache, InMemoryCache, S3Cache from litellm.cost_calculator import _select_model_name_for_cost_calc from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_logger import CustomLogger diff --git a/litellm/llms/AzureOpenAI/azure.py b/litellm/llms/AzureOpenAI/azure.py index 8a89970dc..9e6dc07b2 100644 --- a/litellm/llms/AzureOpenAI/azure.py +++ b/litellm/llms/AzureOpenAI/azure.py @@ -10,7 +10,7 @@ from openai import AsyncAzureOpenAI, AzureOpenAI from typing_extensions import overload import litellm -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.types.utils import EmbeddingResponse diff --git a/litellm/llms/base_aws_llm.py b/litellm/llms/base_aws_llm.py index 348a84180..ba1368a10 100644 --- a/litellm/llms/base_aws_llm.py +++ b/litellm/llms/base_aws_llm.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Tuple import httpx from litellm._logging import verbose_logger -from litellm.caching import DualCache, InMemoryCache +from litellm.caching.caching import DualCache, InMemoryCache from litellm.secret_managers.main import get_secret from .base import BaseLLM diff --git a/litellm/llms/bedrock/chat/invoke_handler.py b/litellm/llms/bedrock/chat/invoke_handler.py index aedbe5787..90267da3a 100644 --- a/litellm/llms/bedrock/chat/invoke_handler.py +++ b/litellm/llms/bedrock/chat/invoke_handler.py @@ -29,7 +29,7 @@ import requests # type: ignore import litellm from litellm import verbose_logger -from litellm.caching import InMemoryCache +from litellm.caching.caching import InMemoryCache from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.litellm_logging import Logging from litellm.llms.custom_httpx.http_handler import ( diff --git a/litellm/llms/prompt_templates/image_handling.py b/litellm/llms/prompt_templates/image_handling.py index 54c1fd27b..d9d7c5383 100644 --- a/litellm/llms/prompt_templates/image_handling.py +++ b/litellm/llms/prompt_templates/image_handling.py @@ -8,7 +8,7 @@ from httpx import Response import litellm from litellm import verbose_logger -from litellm.caching import InMemoryCache +from litellm.caching.caching import InMemoryCache from litellm.llms.custom_httpx.http_handler import ( _get_httpx_client, get_async_httpx_client, diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py index d11906b8c..2dafce6a9 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py @@ -4,7 +4,7 @@ from typing import Callable, List, Literal, Optional, Tuple, Union import httpx import litellm -from litellm.caching import Cache +from litellm.caching.caching import Cache from litellm.litellm_core_utils.litellm_logging import Logging from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from 
litellm.llms.OpenAI.openai import AllMessageValues diff --git a/litellm/main.py b/litellm/main.py index 605a71f8c..fe1453836 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -65,7 +65,7 @@ from litellm.utils import ( ) from ._logging import verbose_logger -from .caching import disable_cache, enable_cache, update_cache +from .caching.caching import disable_cache, enable_cache, update_cache from .llms import ( aleph_alpha, baseten, diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 49f2953c1..92276aca8 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -16,7 +16,7 @@ from pydantic import BaseModel import litellm from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import ( LiteLLM_EndUserTable, LiteLLM_JWTAuth, diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index b39064ae6..4d7d64b79 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -15,7 +15,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable from litellm.proxy.utils import PrismaClient diff --git a/litellm/proxy/caching_routes.py b/litellm/proxy/caching_routes.py index 6f07fcb9a..eacd997d3 100644 --- a/litellm/proxy/caching_routes.py +++ b/litellm/proxy/caching_routes.py @@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request import litellm from litellm._logging import verbose_proxy_logger -from litellm.caching import RedisCache +from litellm.caching.caching import RedisCache from litellm.proxy.auth.user_api_key_auth import user_api_key_auth router = APIRouter( diff --git a/litellm/proxy/example_config_yaml/custom_guardrail.py b/litellm/proxy/example_config_yaml/custom_guardrail.py index 598d5e0bd..abd5b672c 100644 --- a/litellm/proxy/example_config_yaml/custom_guardrail.py +++ b/litellm/proxy/example_config_yaml/custom_guardrail.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Literal, Optional, Union import litellm from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata diff --git a/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py b/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py index 781c547aa..3795155b4 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py +++ b/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py @@ -25,7 +25,7 @@ from fastapi import HTTPException import litellm from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.litellm_core_utils.logging_utils import ( convert_litellm_response_object_to_str, diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py index 
62ac5bc3d..1127e4e7c 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py +++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -25,7 +25,7 @@ from fastapi import HTTPException import litellm from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.litellm_core_utils.logging_utils import ( convert_litellm_response_object_to_str, diff --git a/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py b/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py index 0d834f24f..d00586b29 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py +++ b/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Literal, Optional, Union import litellm from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py index 0708b27eb..da53e4a8a 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py +++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py @@ -21,7 +21,7 @@ from pydantic import BaseModel import litellm # noqa: E401 from litellm import get_secret from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.proxy._types import UserAPIKeyAuth from litellm.utils import ( diff --git a/litellm/proxy/health_endpoints/_health_endpoints.py b/litellm/proxy/health_endpoints/_health_endpoints.py index f268b429e..f9e3f5320 100644 --- a/litellm/proxy/health_endpoints/_health_endpoints.py +++ b/litellm/proxy/health_endpoints/_health_endpoints.py @@ -465,7 +465,7 @@ async def health_readiness(): # check Cache cache_type = None if litellm.cache is not None: - from litellm.caching import RedisSemanticCache + from litellm.caching.caching import RedisSemanticCache cache_type = litellm.cache.type diff --git a/litellm/proxy/hooks/azure_content_safety.py b/litellm/proxy/hooks/azure_content_safety.py index 58fbb179e..4a5db3b20 100644 --- a/litellm/proxy/hooks/azure_content_safety.py +++ b/litellm/proxy/hooks/azure_content_safety.py @@ -7,7 +7,7 @@ from fastapi import HTTPException import litellm from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth diff --git a/litellm/proxy/hooks/batch_redis_get.py b/litellm/proxy/hooks/batch_redis_get.py index 13d95ab91..a6b69e99f 100644 --- a/litellm/proxy/hooks/batch_redis_get.py +++ b/litellm/proxy/hooks/batch_redis_get.py @@ -11,7 +11,7 @@ from fastapi import HTTPException import litellm from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache, InMemoryCache, RedisCache +from litellm.caching.caching import DualCache, InMemoryCache, RedisCache from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types 
import UserAPIKeyAuth diff --git a/litellm/proxy/hooks/cache_control_check.py b/litellm/proxy/hooks/cache_control_check.py index d933bfc75..a5e53fc2f 100644 --- a/litellm/proxy/hooks/cache_control_check.py +++ b/litellm/proxy/hooks/cache_control_check.py @@ -7,7 +7,7 @@ from fastapi import HTTPException import litellm from litellm import verbose_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth diff --git a/litellm/proxy/hooks/dynamic_rate_limiter.py b/litellm/proxy/hooks/dynamic_rate_limiter.py index 1ef674b7e..f0b8113c4 100644 --- a/litellm/proxy/hooks/dynamic_rate_limiter.py +++ b/litellm/proxy/hooks/dynamic_rate_limiter.py @@ -14,7 +14,7 @@ from fastapi import HTTPException import litellm from litellm import ModelResponse, Router from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth from litellm.types.router import ModelGroupInfo diff --git a/litellm/proxy/hooks/max_budget_limiter.py b/litellm/proxy/hooks/max_budget_limiter.py index c8686302f..8fa7a33a0 100644 --- a/litellm/proxy/hooks/max_budget_limiter.py +++ b/litellm/proxy/hooks/max_budget_limiter.py @@ -4,7 +4,7 @@ from fastapi import HTTPException import litellm from litellm import verbose_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 46b1fd562..36e5fecff 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -9,7 +9,7 @@ from pydantic import BaseModel import litellm from litellm import ModelResponse from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs from litellm.proxy._types import CurrentItemRateLimit, UserAPIKeyAuth diff --git a/litellm/proxy/hooks/presidio_pii_masking.py b/litellm/proxy/hooks/presidio_pii_masking.py index ba00630bb..603e07562 100644 --- a/litellm/proxy/hooks/presidio_pii_masking.py +++ b/litellm/proxy/hooks/presidio_pii_masking.py @@ -19,7 +19,7 @@ from fastapi import HTTPException import litellm # noqa: E401 from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth from litellm.utils import ( diff --git a/litellm/proxy/hooks/prompt_injection_detection.py b/litellm/proxy/hooks/prompt_injection_detection.py index dc161f3e5..bbe820ffd 100644 --- a/litellm/proxy/hooks/prompt_injection_detection.py +++ b/litellm/proxy/hooks/prompt_injection_detection.py @@ -18,7 +18,7 @@ from typing_extensions import overload import litellm from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from 
litellm.llms.prompt_templates.factory import prompt_injection_detection_default_pt from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 8407b4e86..58a4ff346 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -112,7 +112,7 @@ from litellm import ( RetrieveBatchRequest, ) from litellm._logging import verbose_proxy_logger, verbose_router_logger -from litellm.caching import DualCache, RedisCache +from litellm.caching.caching import DualCache, RedisCache from litellm.exceptions import RejectedRequestError from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting from litellm.litellm_core_utils.core_helpers import ( @@ -1554,7 +1554,7 @@ class ProxyConfig: for key, value in litellm_settings.items(): if key == "cache" and value is True: print(f"{blue_color_code}\nSetting Cache on Proxy") # noqa - from litellm.caching import Cache + from litellm.caching.caching import Cache cache_params = {} if "cache_params" in litellm_settings: diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 25de59c8e..db88402f8 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -49,7 +49,7 @@ from litellm import ( ) from litellm._logging import verbose_proxy_logger from litellm._service_logger import ServiceLogging, ServiceTypes -from litellm.caching import DualCache, RedisCache +from litellm.caching.caching import DualCache, RedisCache from litellm.exceptions import RejectedRequestError from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_logger import CustomLogger diff --git a/litellm/router.py b/litellm/router.py index 8b4163cb7..845017465 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -37,7 +37,7 @@ import litellm.litellm_core_utils.exception_mapping_utils from litellm import get_secret_str from litellm._logging import verbose_router_logger from litellm.assistants.main import AssistantDeleted -from litellm.caching import DualCache, InMemoryCache, RedisCache +from litellm.caching.caching import DualCache, InMemoryCache, RedisCache from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc diff --git a/litellm/router_strategy/least_busy.py b/litellm/router_strategy/least_busy.py index a6d0ef5df..f1b35bb89 100644 --- a/litellm/router_strategy/least_busy.py +++ b/litellm/router_strategy/least_busy.py @@ -14,7 +14,7 @@ from typing import Optional import dotenv # type: ignore import requests -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger diff --git a/litellm/router_strategy/lowest_cost.py b/litellm/router_strategy/lowest_cost.py index e8b04c7fb..a3bee348b 100644 --- a/litellm/router_strategy/lowest_cost.py +++ b/litellm/router_strategy/lowest_cost.py @@ -9,7 +9,7 @@ from pydantic import BaseModel import litellm from litellm import ModelResponse, token_counter, verbose_logger from litellm._logging import verbose_router_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger diff --git a/litellm/router_strategy/lowest_latency.py b/litellm/router_strategy/lowest_latency.py index bd5c34335..4eb9c967f 100644 --- 
a/litellm/router_strategy/lowest_latency.py +++ b/litellm/router_strategy/lowest_latency.py @@ -9,7 +9,7 @@ from pydantic import BaseModel import litellm from litellm import ModelResponse, token_counter, verbose_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger diff --git a/litellm/router_strategy/lowest_tpm_rpm.py b/litellm/router_strategy/lowest_tpm_rpm.py index fd73e0da2..96f655b01 100644 --- a/litellm/router_strategy/lowest_tpm_rpm.py +++ b/litellm/router_strategy/lowest_tpm_rpm.py @@ -12,7 +12,7 @@ from pydantic import BaseModel from litellm import token_counter from litellm._logging import verbose_router_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.utils import print_verbose diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py index 72a49f2bb..e09608422 100644 --- a/litellm/router_strategy/lowest_tpm_rpm_v2.py +++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py @@ -10,7 +10,7 @@ from pydantic import BaseModel import litellm from litellm import token_counter from litellm._logging import verbose_logger, verbose_router_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.types.router import RouterErrors from litellm.utils import get_utc_datetime, print_verbose diff --git a/litellm/router_utils/cooldown_cache.py b/litellm/router_utils/cooldown_cache.py index 376aced36..e30b4a605 100644 --- a/litellm/router_utils/cooldown_cache.py +++ b/litellm/router_utils/cooldown_cache.py @@ -7,7 +7,7 @@ import time from typing import List, Optional, Tuple, TypedDict from litellm import verbose_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache class CooldownCacheValue(TypedDict): diff --git a/litellm/scheduler.py b/litellm/scheduler.py index 8ee7f0e07..23346e982 100644 --- a/litellm/scheduler.py +++ b/litellm/scheduler.py @@ -1,9 +1,11 @@ -import heapq -from pydantic import BaseModel -from typing import Optional import enum -from litellm.caching import DualCache, RedisCache +import heapq +from typing import Optional + +from pydantic import BaseModel + from litellm import print_verbose +from litellm.caching.caching import DualCache, RedisCache class SchedulerCacheKeys(enum.Enum): diff --git a/litellm/secret_managers/google_secret_manager.py b/litellm/secret_managers/google_secret_manager.py index 7d661555a..f21963c38 100644 --- a/litellm/secret_managers/google_secret_manager.py +++ b/litellm/secret_managers/google_secret_manager.py @@ -4,7 +4,7 @@ from typing import Optional import litellm from litellm._logging import verbose_logger -from litellm.caching import InMemoryCache +from litellm.caching.caching import InMemoryCache from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase from litellm.llms.custom_httpx.http_handler import _get_httpx_client from litellm.proxy._types import CommonProxyErrors, KeyManagementSystem diff --git a/litellm/secret_managers/main.py b/litellm/secret_managers/main.py index 7bcbcab11..4c7cb469b 100644 --- a/litellm/secret_managers/main.py +++ b/litellm/secret_managers/main.py @@ -12,7 +12,7 @@ from dotenv import load_dotenv import litellm from litellm._logging import print_verbose, verbose_logger -from litellm.caching import DualCache 
+from litellm.caching.caching import DualCache from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.proxy._types import KeyManagementSystem diff --git a/litellm/utils.py b/litellm/utils.py index 9524838d3..a79a16a58 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -56,7 +56,10 @@ import litellm._service_logger # for storing API inputs, outputs, and metadata import litellm.litellm_core_utils import litellm.litellm_core_utils.audio_utils.utils import litellm.litellm_core_utils.json_validation_rule -from litellm.caching import DualCache +from litellm.caching.caching import DualCache +from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler + +_llm_caching_handler = LLMCachingHandler() from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.exception_mapping_utils import ( @@ -146,7 +149,13 @@ from typing import ( from openai import OpenAIError as OriginalError from ._logging import verbose_logger -from .caching import Cache, QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache +from .caching.caching import ( + Cache, + QdrantSemanticCache, + RedisCache, + RedisSemanticCache, + S3Cache, +) from .exceptions import ( APIConnectionError, APIError, @@ -1121,299 +1130,26 @@ def client(original_function): print_verbose( f"ASYNC kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}; kwargs.get('cache'): {kwargs.get('cache', None)}" ) - # if caching is false, don't run this - final_embedding_cached_response = None - + _caching_handler_response: CachingHandlerResponse = ( + await _llm_caching_handler._async_get_cache( + model=model, + original_function=original_function, + logging_obj=logging_obj, + start_time=start_time, + call_type=call_type, + kwargs=kwargs, + args=args, + ) + ) if ( - (kwargs.get("caching", None) is None and litellm.cache is not None) - or kwargs.get("caching", False) is True - ) and ( - kwargs.get("cache", {}).get("no-cache", False) is not True - ): # allow users to control returning cached responses from the completion function - # checking cache - print_verbose("INSIDE CHECKING CACHE") - if ( - litellm.cache is not None - and litellm.cache.supported_call_types is not None - and str(original_function.__name__) - in litellm.cache.supported_call_types - ): - print_verbose("Checking Cache") - if call_type == CallTypes.aembedding.value and isinstance( - kwargs["input"], list - ): - tasks = [] - for idx, i in enumerate(kwargs["input"]): - preset_cache_key = litellm.cache.get_cache_key( - *args, **{**kwargs, "input": i} - ) - tasks.append( - litellm.cache.async_get_cache( - cache_key=preset_cache_key - ) - ) - cached_result = await asyncio.gather(*tasks) - ## check if cached result is None ## - if cached_result is not None and isinstance( - cached_result, list - ): - if len(cached_result) == 1 and cached_result[0] is None: - cached_result = None - elif isinstance( - litellm.cache.cache, RedisSemanticCache - ) or isinstance(litellm.cache.cache, RedisCache): - preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs["preset_cache_key"] = ( - preset_cache_key # for streaming calls, we need to pass the preset_cache_key - ) - cached_result = await litellm.cache.async_get_cache( - *args, **kwargs - ) - elif isinstance(litellm.cache.cache, QdrantSemanticCache): - preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs["preset_cache_key"] = ( - 
preset_cache_key # for streaming calls, we need to pass the preset_cache_key - ) - cached_result = await litellm.cache.async_get_cache( - *args, **kwargs - ) - else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync] - preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs["preset_cache_key"] = ( - preset_cache_key # for streaming calls, we need to pass the preset_cache_key - ) - cached_result = litellm.cache.get_cache(*args, **kwargs) - if cached_result is not None and not isinstance( - cached_result, list - ): - print_verbose("Cache Hit!", log_level="INFO") - cache_hit = True - end_time = datetime.datetime.now() - ( - model, - custom_llm_provider, - dynamic_api_key, - api_base, - ) = litellm.get_llm_provider( - model=model, - custom_llm_provider=kwargs.get("custom_llm_provider", None), - api_base=kwargs.get("api_base", None), - api_key=kwargs.get("api_key", None), - ) - print_verbose( - f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" - ) - logging_obj.update_environment_variables( - model=model, - user=kwargs.get("user", None), - optional_params={}, - litellm_params={ - "logger_fn": kwargs.get("logger_fn", None), - "acompletion": True, - "metadata": kwargs.get("metadata", {}), - "model_info": kwargs.get("model_info", {}), - "proxy_server_request": kwargs.get( - "proxy_server_request", None - ), - "preset_cache_key": kwargs.get( - "preset_cache_key", None - ), - "stream_response": kwargs.get("stream_response", {}), - "api_base": kwargs.get("api_base", ""), - }, - input=kwargs.get("messages", ""), - api_key=kwargs.get("api_key", None), - original_response=str(cached_result), - additional_args=None, - stream=kwargs.get("stream", False), - ) - call_type = original_function.__name__ - if call_type == CallTypes.acompletion.value and isinstance( - cached_result, dict - ): - if kwargs.get("stream", False) is True: - cached_result = convert_to_streaming_response_async( - response_object=cached_result, - ) - cached_result = CustomStreamWrapper( - completion_stream=cached_result, - model=model, - custom_llm_provider="cached_response", - logging_obj=logging_obj, - ) - else: - cached_result = convert_to_model_response_object( - response_object=cached_result, - model_response_object=ModelResponse(), - ) - if ( - call_type == CallTypes.atext_completion.value - and isinstance(cached_result, dict) - ): - if kwargs.get("stream", False) is True: - cached_result = convert_to_streaming_response_async( - response_object=cached_result, - ) - cached_result = CustomStreamWrapper( - completion_stream=cached_result, - model=model, - custom_llm_provider="cached_response", - logging_obj=logging_obj, - ) - else: - cached_result = TextCompletionResponse(**cached_result) - elif call_type == CallTypes.aembedding.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - model_response_object=EmbeddingResponse(), - response_type="embedding", - ) - elif call_type == CallTypes.arerank.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - model_response_object=None, - response_type="rerank", - ) - elif call_type == CallTypes.atranscription.value and isinstance( - cached_result, dict - ): - hidden_params = { - "model": "whisper-1", - "custom_llm_provider": custom_llm_provider, - "cache_hit": True, - } - cached_result = convert_to_model_response_object( - 
response_object=cached_result, - model_response_object=TranscriptionResponse(), - response_type="audio_transcription", - hidden_params=hidden_params, - ) - if kwargs.get("stream", False) is False: - # LOG SUCCESS - asyncio.create_task( - logging_obj.async_success_handler( - cached_result, start_time, end_time, cache_hit - ) - ) - threading.Thread( - target=logging_obj.success_handler, - args=(cached_result, start_time, end_time, cache_hit), - ).start() - cache_key = kwargs.get("preset_cache_key", None) - if ( - isinstance(cached_result, BaseModel) - or isinstance(cached_result, CustomStreamWrapper) - ) and hasattr(cached_result, "_hidden_params"): - cached_result._hidden_params["cache_key"] = cache_key # type: ignore - return cached_result - elif ( - call_type == CallTypes.aembedding.value - and cached_result is not None - and isinstance(cached_result, list) - and litellm.cache is not None - and not isinstance( - litellm.cache.cache, S3Cache - ) # s3 doesn't support bulk writing. Exclude. - ): - remaining_list = [] - non_null_list = [] - for idx, cr in enumerate(cached_result): - if cr is None: - remaining_list.append(kwargs["input"][idx]) - else: - non_null_list.append((idx, cr)) - original_kwargs_input = kwargs["input"] - kwargs["input"] = remaining_list - if len(non_null_list) > 0: - print_verbose( - f"EMBEDDING CACHE HIT! - {len(non_null_list)}" - ) - final_embedding_cached_response = EmbeddingResponse( - model=kwargs.get("model"), - data=[None] * len(original_kwargs_input), - ) - final_embedding_cached_response._hidden_params[ - "cache_hit" - ] = True + _caching_handler_response.cached_result is not None + and _caching_handler_response.final_embedding_cached_response is None + ): + return _caching_handler_response.cached_result + + elif _caching_handler_response.embedding_all_elements_cache_hit is True: + return _caching_handler_response.final_embedding_cached_response - for val in non_null_list: - idx, cr = val # (idx, cr) tuple - if cr is not None: - final_embedding_cached_response.data[idx] = ( - Embedding( - embedding=cr["embedding"], - index=idx, - object="embedding", - ) - ) - if len(remaining_list) == 0: - # LOG SUCCESS - cache_hit = True - end_time = datetime.datetime.now() - ( - model, - custom_llm_provider, - dynamic_api_key, - api_base, - ) = litellm.get_llm_provider( - model=model, - custom_llm_provider=kwargs.get( - "custom_llm_provider", None - ), - api_base=kwargs.get("api_base", None), - api_key=kwargs.get("api_key", None), - ) - print_verbose( - f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" - ) - logging_obj.update_environment_variables( - model=model, - user=kwargs.get("user", None), - optional_params={}, - litellm_params={ - "logger_fn": kwargs.get("logger_fn", None), - "acompletion": True, - "metadata": kwargs.get("metadata", {}), - "model_info": kwargs.get("model_info", {}), - "proxy_server_request": kwargs.get( - "proxy_server_request", None - ), - "preset_cache_key": kwargs.get( - "preset_cache_key", None - ), - "stream_response": kwargs.get( - "stream_response", {} - ), - "api_base": "", - }, - input=kwargs.get("messages", ""), - api_key=kwargs.get("api_key", None), - original_response=str(final_embedding_cached_response), - additional_args=None, - stream=kwargs.get("stream", False), - ) - asyncio.create_task( - logging_obj.async_success_handler( - final_embedding_cached_response, - start_time, - end_time, - cache_hit, - ) - ) - threading.Thread( - target=logging_obj.success_handler, - args=( - 
final_embedding_cached_response, - start_time, - end_time, - cache_hit, - ), - ).start() - return final_embedding_cached_response # MODEL CALL result = await original_function(*args, **kwargs) end_time = datetime.datetime.now() @@ -1467,51 +1203,14 @@ def client(original_function): original_response=result, model=model, optional_params=kwargs ) - # [OPTIONAL] ADD TO CACHE - if ( - (litellm.cache is not None) - and litellm.cache.supported_call_types is not None - and ( - str(original_function.__name__) - in litellm.cache.supported_call_types - ) - and (kwargs.get("cache", {}).get("no-store", False) is not True) - ): - if ( - isinstance(result, litellm.ModelResponse) - or isinstance(result, litellm.EmbeddingResponse) - or isinstance(result, TranscriptionResponse) - or isinstance(result, RerankResponse) - ): - if ( - isinstance(result, EmbeddingResponse) - and isinstance(kwargs["input"], list) - and litellm.cache is not None - and not isinstance( - litellm.cache.cache, S3Cache - ) # s3 doesn't support bulk writing. Exclude. - ): - asyncio.create_task( - litellm.cache.async_add_cache_pipeline( - result, *args, **kwargs - ) - ) - elif isinstance(litellm.cache.cache, S3Cache): - threading.Thread( - target=litellm.cache.add_cache, - args=(result,) + args, - kwargs=kwargs, - ).start() - else: - asyncio.create_task( - litellm.cache.async_add_cache( - result.json(), *args, **kwargs - ) - ) - else: - asyncio.create_task( - litellm.cache.async_add_cache(result, *args, **kwargs) - ) + ## Add response to cache + await _llm_caching_handler._async_set_cache( + result=result, + original_function=original_function, + kwargs=kwargs, + args=args, + ) + # LOG SUCCESS - handle streaming success logging in the _next_ object print_verbose( f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" @@ -1528,24 +1227,32 @@ def client(original_function): # REBUILD EMBEDDING CACHING if ( isinstance(result, EmbeddingResponse) - and final_embedding_cached_response is not None - and final_embedding_cached_response.data is not None + and _caching_handler_response.final_embedding_cached_response + is not None + and _caching_handler_response.final_embedding_cached_response.data + is not None ): idx = 0 final_data_list = [] - for item in final_embedding_cached_response.data: + for ( + item + ) in _caching_handler_response.final_embedding_cached_response.data: if item is None and result.data is not None: final_data_list.append(result.data[idx]) idx += 1 else: final_data_list.append(item) - final_embedding_cached_response.data = final_data_list - final_embedding_cached_response._hidden_params["cache_hit"] = True - final_embedding_cached_response._response_ms = ( + _caching_handler_response.final_embedding_cached_response.data = ( + final_data_list + ) + _caching_handler_response.final_embedding_cached_response._hidden_params[ + "cache_hit" + ] = True + _caching_handler_response.final_embedding_cached_response._response_ms = ( end_time - start_time ).total_seconds() * 1000 - return final_embedding_cached_response + return _caching_handler_response.final_embedding_cached_response return result except Exception as e: diff --git a/tests/local_testing/test_add_update_models.py b/tests/local_testing/test_add_update_models.py index 738b9025a..b155a7cc5 100644 --- a/tests/local_testing/test_add_update_models.py +++ b/tests/local_testing/test_add_update_models.py @@ -19,7 +19,7 @@ from litellm._logging import verbose_proxy_logger from litellm.proxy.utils import PrismaClient, ProxyLogging 
verbose_proxy_logger.setLevel(level=logging.DEBUG) -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.router import ( Deployment, updateDeployment, diff --git a/tests/local_testing/test_alerting.py b/tests/local_testing/test_alerting.py index 5785e829b..b79438ffc 100644 --- a/tests/local_testing/test_alerting.py +++ b/tests/local_testing/test_alerting.py @@ -28,7 +28,7 @@ import pytest from openai import APIError import litellm -from litellm.caching import DualCache, RedisCache +from litellm.caching.caching import DualCache, RedisCache from litellm.integrations.SlackAlerting.slack_alerting import ( DeploymentMetrics, SlackAlerting, diff --git a/tests/local_testing/test_auth_checks.py b/tests/local_testing/test_auth_checks.py index 8bc8f7d14..3ea113c28 100644 --- a/tests/local_testing/test_auth_checks.py +++ b/tests/local_testing/test_auth_checks.py @@ -13,7 +13,7 @@ sys.path.insert( ) # Adds the parent directory to the system path import pytest, litellm from litellm.proxy.auth.auth_checks import get_end_user_object -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import LiteLLM_EndUserTable, LiteLLM_BudgetTable from litellm.proxy.utils import PrismaClient diff --git a/tests/local_testing/test_azure_content_safety.py b/tests/local_testing/test_azure_content_safety.py index dc80c163c..91eb92b74 100644 --- a/tests/local_testing/test_azure_content_safety.py +++ b/tests/local_testing/test_azure_content_safety.py @@ -21,7 +21,7 @@ import pytest import litellm from litellm import Router, mock_completion -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.utils import ProxyLogging diff --git a/tests/local_testing/test_banned_keyword_list.py b/tests/local_testing/test_banned_keyword_list.py index 54d8852e8..90066b74f 100644 --- a/tests/local_testing/test_banned_keyword_list.py +++ b/tests/local_testing/test_banned_keyword_list.py @@ -21,7 +21,7 @@ from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import ( from litellm import Router, mock_completion from litellm.proxy.utils import ProxyLogging, hash_token from litellm.proxy._types import UserAPIKeyAuth -from litellm.caching import DualCache +from litellm.caching.caching import DualCache @pytest.mark.asyncio diff --git a/tests/local_testing/test_blocked_user_list.py b/tests/local_testing/test_blocked_user_list.py index fb9986dc5..10635befd 100644 --- a/tests/local_testing/test_blocked_user_list.py +++ b/tests/local_testing/test_blocked_user_list.py @@ -27,7 +27,7 @@ import pytest import litellm from litellm import Router, mock_completion from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import ( _ENTERPRISE_BlockedUserList, @@ -56,7 +56,7 @@ verbose_proxy_logger.setLevel(level=logging.DEBUG) from starlette.datastructures import URL -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import ( BlockUsers, DynamoDBArgs, diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py index 8ba788a55..454068de8 100644 --- a/tests/local_testing/test_caching.py +++ b/tests/local_testing/test_caching.py @@ -21,7 +21,7 @@ import pytest import litellm from 
litellm import aembedding, completion, embedding -from litellm.caching import Cache +from litellm.caching.caching import Cache from unittest.mock import AsyncMock, patch, MagicMock import datetime @@ -52,7 +52,7 @@ async def test_dual_cache_async_batch_get_cache(): - hit redis for the other -> expect to return None - expect result = [in_memory_result, None] """ - from litellm.caching import DualCache, InMemoryCache, RedisCache + from litellm.caching.caching import DualCache, InMemoryCache, RedisCache in_memory_cache = InMemoryCache() redis_cache = RedisCache() # get credentials from environment @@ -74,7 +74,7 @@ def test_dual_cache_batch_get_cache(): - hit redis for the other -> expect to return None - expect result = [in_memory_result, None] """ - from litellm.caching import DualCache, InMemoryCache, RedisCache + from litellm.caching.caching import DualCache, InMemoryCache, RedisCache in_memory_cache = InMemoryCache() redis_cache = RedisCache() # get credentials from environment @@ -520,6 +520,7 @@ async def test_embedding_caching_azure_individual_items_reordered(): assert embedding_val_1[0]["id"] == embedding_val_2[0]["id"] ``` """ + litellm.set_verbose = True litellm.cache = Cache() common_msg = f"{uuid.uuid4()}" common_msg_2 = f"hey how's it going {uuid.uuid4()}" @@ -532,9 +533,11 @@ async def test_embedding_caching_azure_individual_items_reordered(): embedding_val_1 = await aembedding( model="azure/azure-embedding-model", input=embedding_1, caching=True ) + print("embedding val 1", embedding_val_1) embedding_val_2 = await aembedding( model="azure/azure-embedding-model", input=embedding_2, caching=True ) + print("embedding val 2", embedding_val_2) print(f"embedding_val_2._hidden_params: {embedding_val_2._hidden_params}") assert embedding_val_2._hidden_params["cache_hit"] == True @@ -866,7 +869,7 @@ async def test_redis_cache_cluster_init_unit_test(): from redis.asyncio import RedisCluster as AsyncRedisCluster from redis.cluster import RedisCluster - from litellm.caching import RedisCache + from litellm.caching.caching import RedisCache litellm.set_verbose = True @@ -900,7 +903,7 @@ async def test_redis_cache_cluster_init_with_env_vars_unit_test(): from redis.asyncio import RedisCluster as AsyncRedisCluster from redis.cluster import RedisCluster - from litellm.caching import RedisCache + from litellm.caching.caching import RedisCache litellm.set_verbose = True @@ -1554,7 +1557,7 @@ def test_custom_redis_cache_params(): def test_get_cache_key(): - from litellm.caching import Cache + from litellm.caching.caching import Cache try: print("Testing get_cache_key") @@ -1989,7 +1992,7 @@ async def test_cache_default_off_acompletion(): verbose_logger.setLevel(logging.DEBUG) - from litellm.caching import CacheMode + from litellm.caching.caching import CacheMode random_number = random.randint( 1, 100000 @@ -2072,7 +2075,7 @@ async def test_dual_cache_uses_redis(): - Assert that value from redis is used """ litellm.set_verbose = True - from litellm.caching import DualCache, RedisCache + from litellm.caching.caching import DualCache, RedisCache current_usage = uuid.uuid4() @@ -2095,7 +2098,7 @@ async def test_proxy_logging_setup(): """ Assert always_read_redis is True when used by internal usage cache """ - from litellm.caching import DualCache + from litellm.caching.caching import DualCache from litellm.proxy.utils import ProxyLogging pl_obj = ProxyLogging(user_api_key_cache=DualCache()) @@ -2165,7 +2168,7 @@ async def test_redis_proxy_batch_redis_get_cache(): - make 2nd call -> expect hit """ - 
from litellm.caching import Cache, DualCache + from litellm.caching.caching import Cache, DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.hooks.batch_redis_get import _PROXY_BatchRedisRequests diff --git a/tests/local_testing/test_caching_ssl.py b/tests/local_testing/test_caching_ssl.py index 84ece8310..0825a8537 100644 --- a/tests/local_testing/test_caching_ssl.py +++ b/tests/local_testing/test_caching_ssl.py @@ -15,7 +15,7 @@ sys.path.insert( import pytest import litellm from litellm import embedding, completion, Router -from litellm.caching import Cache +from litellm.caching.caching import Cache messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}] diff --git a/tests/local_testing/test_datadog.py b/tests/local_testing/test_datadog.py index af496aadf..990a5b76c 100644 --- a/tests/local_testing/test_datadog.py +++ b/tests/local_testing/test_datadog.py @@ -151,7 +151,7 @@ async def test_datadog_log_redis_failures(): Test that poorly configured Redis is logged as Warning on DataDog """ try: - from litellm.caching import Cache + from litellm.caching.caching import Cache from litellm.integrations.datadog.datadog import DataDogLogger litellm.cache = Cache( diff --git a/tests/local_testing/test_jwt.py b/tests/local_testing/test_jwt.py index f226dee37..4bd3f2613 100644 --- a/tests/local_testing/test_jwt.py +++ b/tests/local_testing/test_jwt.py @@ -24,7 +24,7 @@ import pytest from fastapi import Request import litellm -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable, LiteLLMRoutes from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.management_endpoints.team_endpoints import new_team diff --git a/tests/local_testing/test_key_generate_prisma.py b/tests/local_testing/test_key_generate_prisma.py index 4574dd5ae..4098e524a 100644 --- a/tests/local_testing/test_key_generate_prisma.py +++ b/tests/local_testing/test_key_generate_prisma.py @@ -89,7 +89,7 @@ verbose_proxy_logger.setLevel(level=logging.DEBUG) from starlette.datastructures import URL -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import ( DynamoDBArgs, GenerateKeyRequest, @@ -1444,7 +1444,7 @@ def test_call_with_key_over_budget(prisma_client): # update spend using track_cost callback, make 2nd request, it should fail from litellm import Choices, Message, ModelResponse, Usage - from litellm.caching import Cache + from litellm.caching.caching import Cache from litellm.proxy.proxy_server import ( _PROXY_track_cost_callback as track_cost_callback, ) @@ -1564,7 +1564,7 @@ def test_call_with_key_over_budget_no_cache(prisma_client): setattr(litellm.proxy.proxy_server, "proxy_batch_write_at", 1) from litellm import Choices, Message, ModelResponse, Usage - from litellm.caching import Cache + from litellm.caching.caching import Cache litellm.cache = Cache() import time @@ -1685,7 +1685,7 @@ def test_call_with_key_over_model_budget(prisma_client): # update spend using track_cost callback, make 2nd request, it should fail from litellm import Choices, Message, ModelResponse, Usage - from litellm.caching import Cache + from litellm.caching.caching import Cache from litellm.proxy.proxy_server import ( _PROXY_track_cost_callback as track_cost_callback, ) diff --git a/tests/local_testing/test_lakera_ai_prompt_injection.py b/tests/local_testing/test_lakera_ai_prompt_injection.py index 37da1b426..f9035a74f 100644 --- 
a/tests/local_testing/test_lakera_ai_prompt_injection.py +++ b/tests/local_testing/test_lakera_ai_prompt_injection.py @@ -25,7 +25,7 @@ import pytest import litellm from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import lakeraAI_Moderation from litellm.proxy.proxy_server import embeddings diff --git a/tests/local_testing/test_least_busy_routing.py b/tests/local_testing/test_least_busy_routing.py index 8d85b708f..dc7db9560 100644 --- a/tests/local_testing/test_least_busy_routing.py +++ b/tests/local_testing/test_least_busy_routing.py @@ -20,7 +20,7 @@ import pytest import litellm from litellm import Router -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.router_strategy.least_busy import LeastBusyLoggingHandler ### UNIT TESTS FOR LEAST BUSY LOGGING ### diff --git a/tests/local_testing/test_llm_guard.py b/tests/local_testing/test_llm_guard.py index 4775e065d..ff380b74d 100644 --- a/tests/local_testing/test_llm_guard.py +++ b/tests/local_testing/test_llm_guard.py @@ -20,7 +20,7 @@ from litellm.proxy.enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMG from litellm import Router, mock_completion from litellm.proxy.utils import ProxyLogging, hash_token from litellm.proxy._types import UserAPIKeyAuth -from litellm.caching import DualCache +from litellm.caching.caching import DualCache ### UNIT TESTS FOR LLM GUARD ### diff --git a/tests/local_testing/test_load_test_router_s3.py b/tests/local_testing/test_load_test_router_s3.py index 7b2683367..3a022ae99 100644 --- a/tests/local_testing/test_load_test_router_s3.py +++ b/tests/local_testing/test_load_test_router_s3.py @@ -10,7 +10,7 @@ # import asyncio # from litellm import Router, Timeout # import time -# from litellm.caching import Cache +# from litellm.caching.caching import Cache # import litellm # litellm.cache = Cache( diff --git a/tests/local_testing/test_lowest_cost_routing.py b/tests/local_testing/test_lowest_cost_routing.py index a793ba0a2..4e3105b5f 100644 --- a/tests/local_testing/test_lowest_cost_routing.py +++ b/tests/local_testing/test_lowest_cost_routing.py @@ -15,7 +15,7 @@ sys.path.insert( import pytest from litellm import Router from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler -from litellm.caching import DualCache +from litellm.caching.caching import DualCache ### UNIT TESTS FOR cost ROUTING ### diff --git a/tests/local_testing/test_lowest_latency_routing.py b/tests/local_testing/test_lowest_latency_routing.py index 3b6255e53..423449098 100644 --- a/tests/local_testing/test_lowest_latency_routing.py +++ b/tests/local_testing/test_lowest_latency_routing.py @@ -22,7 +22,7 @@ import pytest import litellm from litellm import Router -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler ### UNIT TESTS FOR LATENCY ROUTING ### diff --git a/tests/local_testing/test_max_tpm_rpm_limiter.py b/tests/local_testing/test_max_tpm_rpm_limiter.py index 43489d5d9..29f9a85c4 100644 --- a/tests/local_testing/test_max_tpm_rpm_limiter.py +++ b/tests/local_testing/test_max_tpm_rpm_limiter.py @@ -19,7 +19,7 @@ # from litellm import Router # from litellm.proxy.utils import ProxyLogging, hash_token # from litellm.proxy._types import UserAPIKeyAuth -# from litellm.caching import 
DualCache, RedisCache +# from litellm.caching.caching import DualCache, RedisCache # from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter # from datetime import datetime diff --git a/tests/local_testing/test_openai_moderations_hook.py b/tests/local_testing/test_openai_moderations_hook.py index 745f188fc..2ab866995 100644 --- a/tests/local_testing/test_openai_moderations_hook.py +++ b/tests/local_testing/test_openai_moderations_hook.py @@ -22,7 +22,7 @@ from litellm.proxy.enterprise.enterprise_hooks.openai_moderation import ( from litellm import Router, mock_completion from litellm.proxy.utils import ProxyLogging, hash_token from litellm.proxy._types import UserAPIKeyAuth -from litellm.caching import DualCache +from litellm.caching.caching import DualCache ### UNIT TESTS FOR OpenAI Moderation ### diff --git a/tests/local_testing/test_parallel_request_limiter.py b/tests/local_testing/test_parallel_request_limiter.py index b52d4fc69..d0a9f9843 100644 --- a/tests/local_testing/test_parallel_request_limiter.py +++ b/tests/local_testing/test_parallel_request_limiter.py @@ -23,7 +23,7 @@ import pytest import litellm from litellm import Router -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.hooks.parallel_request_limiter import ( _PROXY_MaxParallelRequestsHandler as MaxParallelRequestsHandler, diff --git a/tests/local_testing/test_presidio_masking.py b/tests/local_testing/test_presidio_masking.py index 35a03ea5e..0f96da334 100644 --- a/tests/local_testing/test_presidio_masking.py +++ b/tests/local_testing/test_presidio_masking.py @@ -22,7 +22,7 @@ import pytest import litellm from litellm import Router, mock_completion -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.hooks.presidio_pii_masking import _OPTIONAL_PresidioPIIMasking from litellm.proxy.utils import ProxyLogging diff --git a/tests/local_testing/test_prometheus_service.py b/tests/local_testing/test_prometheus_service.py index 9e3441abb..86321ea2d 100644 --- a/tests/local_testing/test_prometheus_service.py +++ b/tests/local_testing/test_prometheus_service.py @@ -67,7 +67,7 @@ async def test_completion_with_caching_bad_call(): litellm.set_verbose = True try: - from litellm.caching import RedisCache + from litellm.caching.caching import RedisCache litellm.service_callback = ["prometheus_system"] sl = ServiceLogging(mock_testing=True) diff --git a/tests/local_testing/test_prompt_injection_detection.py b/tests/local_testing/test_prompt_injection_detection.py index e170dbf81..c493a3722 100644 --- a/tests/local_testing/test_prompt_injection_detection.py +++ b/tests/local_testing/test_prompt_injection_detection.py @@ -20,7 +20,7 @@ from litellm.proxy.hooks.prompt_injection_detection import ( from litellm import Router, mock_completion from litellm.proxy.utils import ProxyLogging from litellm.proxy._types import UserAPIKeyAuth, LiteLLMPromptInjectionParams -from litellm.caching import DualCache +from litellm.caching.caching import DualCache @pytest.mark.asyncio diff --git a/tests/local_testing/test_proxy_reject_logging.py b/tests/local_testing/test_proxy_reject_logging.py index 2b6bcaab2..756a23115 100644 --- a/tests/local_testing/test_proxy_reject_logging.py +++ b/tests/local_testing/test_proxy_reject_logging.py @@ -31,7 +31,7 @@ from starlette.datastructures import URL import litellm from litellm import Router, 
mock_completion -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.enterprise.enterprise_hooks.secret_detection import ( diff --git a/tests/local_testing/test_proxy_server.py b/tests/local_testing/test_proxy_server.py index 98b5f058d..fcb23b125 100644 --- a/tests/local_testing/test_proxy_server.py +++ b/tests/local_testing/test_proxy_server.py @@ -745,7 +745,7 @@ async def test_team_update_redis(): """ Tests if team update, updates the redis cache if set """ - from litellm.caching import DualCache, RedisCache + from litellm.caching.caching import DualCache, RedisCache from litellm.proxy._types import LiteLLM_TeamTableCachedObj from litellm.proxy.auth.auth_checks import _cache_team_object @@ -775,7 +775,7 @@ async def test_get_team_redis(client_no_auth): """ Tests if get_team_object gets value from redis cache, if set """ - from litellm.caching import DualCache, RedisCache + from litellm.caching.caching import DualCache, RedisCache from litellm.proxy.auth.auth_checks import get_team_object proxy_logging_obj: ProxyLogging = getattr( diff --git a/tests/local_testing/test_secret_detect_hook.py b/tests/local_testing/test_secret_detect_hook.py index 2c2007164..e931198e8 100644 --- a/tests/local_testing/test_secret_detect_hook.py +++ b/tests/local_testing/test_secret_detect_hook.py @@ -26,7 +26,7 @@ from starlette.datastructures import URL import litellm from litellm import Router, mock_completion -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.enterprise.enterprise_hooks.secret_detection import ( diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py index d64134aa8..b912d98f3 100644 --- a/tests/local_testing/test_streaming.py +++ b/tests/local_testing/test_streaming.py @@ -3128,7 +3128,7 @@ async def test_azure_astreaming_and_function_calling(): "content": f"What is the weather like in Boston? 
{uuid.uuid4()}", } ] - from litellm.caching import Cache + from litellm.caching.caching import Cache litellm.cache = Cache( type="redis", diff --git a/tests/local_testing/test_tpm_rpm_routing_v2.py b/tests/local_testing/test_tpm_rpm_routing_v2.py index 259bd0ee0..61b17d356 100644 --- a/tests/local_testing/test_tpm_rpm_routing_v2.py +++ b/tests/local_testing/test_tpm_rpm_routing_v2.py @@ -23,7 +23,7 @@ import pytest import litellm from litellm import Router -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.router_strategy.lowest_tpm_rpm_v2 import ( LowestTPMLoggingHandler_v2 as LowestTPMLoggingHandler, ) diff --git a/tests/local_testing/test_update_spend.py b/tests/local_testing/test_update_spend.py index e716f9b4b..3eb9f1ab4 100644 --- a/tests/local_testing/test_update_spend.py +++ b/tests/local_testing/test_update_spend.py @@ -27,7 +27,7 @@ import pytest import litellm from litellm import Router, mock_completion from litellm._logging import verbose_proxy_logger -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.management_endpoints.internal_user_endpoints import ( new_user, @@ -53,7 +53,7 @@ verbose_proxy_logger.setLevel(level=logging.DEBUG) from starlette.datastructures import URL -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import ( BlockUsers, DynamoDBArgs, diff --git a/tests/local_testing/test_whisper.py b/tests/local_testing/test_whisper.py index 44da9b386..087028928 100644 --- a/tests/local_testing/test_whisper.py +++ b/tests/local_testing/test_whisper.py @@ -157,7 +157,7 @@ async def test_transcription_on_router(): @pytest.mark.asyncio() async def test_transcription_caching(): import litellm - from litellm.caching import Cache + from litellm.caching.caching import Cache litellm.set_verbose = True litellm.cache = Cache() diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index 8890ba844..bb1445e9a 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -71,7 +71,7 @@ verbose_proxy_logger.setLevel(level=logging.DEBUG) from starlette.datastructures import URL -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import ( DynamoDBArgs, GenerateKeyRequest, diff --git a/tests/proxy_admin_ui_tests/test_role_based_access.py b/tests/proxy_admin_ui_tests/test_role_based_access.py index d851ca568..e2727e5d8 100644 --- a/tests/proxy_admin_ui_tests/test_role_based_access.py +++ b/tests/proxy_admin_ui_tests/test_role_based_access.py @@ -78,7 +78,7 @@ verbose_proxy_logger.setLevel(level=logging.DEBUG) from starlette.datastructures import URL -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import * proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) diff --git a/tests/proxy_admin_ui_tests/test_sso_sign_in.py b/tests/proxy_admin_ui_tests/test_sso_sign_in.py index 5814b427f..17ee445ac 100644 --- a/tests/proxy_admin_ui_tests/test_sso_sign_in.py +++ b/tests/proxy_admin_ui_tests/test_sso_sign_in.py @@ -17,7 +17,7 @@ from litellm.proxy._types import LitellmUserRoles import os import jwt import time -from litellm.caching import DualCache +from litellm.caching.caching import DualCache proxy_logging_obj = 
ProxyLogging(user_api_key_cache=DualCache()) diff --git a/tests/proxy_admin_ui_tests/test_usage_endpoints.py b/tests/proxy_admin_ui_tests/test_usage_endpoints.py index 6b5946c38..4a9ba9588 100644 --- a/tests/proxy_admin_ui_tests/test_usage_endpoints.py +++ b/tests/proxy_admin_ui_tests/test_usage_endpoints.py @@ -85,7 +85,7 @@ verbose_proxy_logger.setLevel(level=logging.DEBUG) from starlette.datastructures import URL -from litellm.caching import DualCache +from litellm.caching.caching import DualCache from litellm.proxy._types import ( DynamoDBArgs, GenerateKeyRequest,