diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index 2c3c863de5..80dc912220 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -707,6 +707,40 @@ async def test_redis_cache_acompletion_stream():
 # test_redis_cache_acompletion_stream()
 
 
+@pytest.mark.asyncio
+async def test_redis_cache_atext_completion():
+    try:
+        litellm.set_verbose = True
+        prompt = f"write a one sentence poem about: {uuid.uuid4()}"
+        litellm.cache = Cache(
+            type="redis",
+            host=os.environ["REDIS_HOST"],
+            port=os.environ["REDIS_PORT"],
+            password=os.environ["REDIS_PASSWORD"],
+            supported_call_types=["atext_completion"],
+        )
+        print("test for caching, atext_completion")
+
+        response1 = await litellm.atext_completion(
+            model="gpt-3.5-turbo-instruct", prompt=prompt, max_tokens=40, temperature=1
+        )
+
+        await asyncio.sleep(0.5)
+        print("\n\n Response 1 content: ", response1, "\n\n")
+
+        response2 = await litellm.atext_completion(
+            model="gpt-3.5-turbo-instruct", prompt=prompt, max_tokens=40, temperature=1
+        )
+
+        print(response2)
+
+        assert response1 == response2
+        assert response1.choices == response2.choices
+    except Exception as e:
+        print(f"{str(e)}\n\n{traceback.format_exc()}")
+        raise e
+
+
 @pytest.mark.asyncio
 async def test_redis_cache_acompletion_stream_bedrock():
     import asyncio
diff --git a/litellm/utils.py b/litellm/utils.py
index 537a800059..d1133affb4 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2996,7 +2996,7 @@ def client(original_function):
                 )
             ):  # allow users to control returning cached responses from the completion function
                 # checking cache
-                print_verbose(f"INSIDE CHECKING CACHE")
+                print_verbose("INSIDE CHECKING CACHE")
                 if (
                     litellm.cache is not None
                     and str(original_function.__name__)
@@ -3103,6 +3103,22 @@ def client(original_function):
                                     response_object=cached_result,
                                     model_response_object=ModelResponse(),
                                 )
+                        if (
+                            call_type == CallTypes.atext_completion.value
+                            and isinstance(cached_result, dict)
+                        ):
+                            if kwargs.get("stream", False) == True:
+                                cached_result = convert_to_streaming_response_async(
+                                    response_object=cached_result,
+                                )
+                                cached_result = CustomStreamWrapper(
+                                    completion_stream=cached_result,
+                                    model=model,
+                                    custom_llm_provider="cached_response",
+                                    logging_obj=logging_obj,
+                                )
+                            else:
+                                cached_result = TextCompletionResponse(**cached_result)
                         elif call_type == CallTypes.aembedding.value and isinstance(
                             cached_result, dict
                         ):
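
Not part of the patch: a minimal usage sketch of the cache path this change enables for atext_completion. It simply mirrors the new test above; the import path for Cache, the REDIS_HOST/REDIS_PORT/REDIS_PASSWORD environment variables, and the model name are assumptions about a local setup, not something this diff prescribes.

```python
import asyncio
import os
import uuid

import litellm
from litellm.caching import Cache  # assumed import path, as used in litellm's caching tests


async def main():
    # Route atext_completion calls through the Redis cache (mirrors the new test).
    litellm.cache = Cache(
        type="redis",
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
        supported_call_types=["atext_completion"],
    )

    prompt = f"write a one sentence poem about: {uuid.uuid4()}"

    # First call populates the cache; the second identical call should be served
    # from Redis via the new TextCompletionResponse(**cached_result) branch in utils.py.
    response1 = await litellm.atext_completion(
        model="gpt-3.5-turbo-instruct", prompt=prompt, max_tokens=40
    )
    await asyncio.sleep(0.5)
    response2 = await litellm.atext_completion(
        model="gpt-3.5-turbo-instruct", prompt=prompt, max_tokens=40
    )
    assert response1.choices == response2.choices


if __name__ == "__main__":
    asyncio.run(main())
```

Passing supported_call_types=["atext_completion"] scopes caching to this call type only, which is how the test isolates the new code path.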