diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index 75d0ba55f..897e9c3b2 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -936,7 +936,14 @@
         "mode": "chat"
     },
     "openrouter/mistralai/mistral-7b-instruct": {
-        "max_tokens": 4096,
+        "max_tokens": 8192,
+        "input_cost_per_token": 0.00000013,
+        "output_cost_per_token": 0.00000013,
+        "litellm_provider": "openrouter",
+        "mode": "chat"
+    },
+    "openrouter/mistralai/mistral-7b-instruct:free": {
+        "max_tokens": 8192,
         "input_cost_per_token": 0.0,
         "output_cost_per_token": 0.0,
         "litellm_provider": "openrouter",
diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py
index 579fe6583..25d531fda 100644
--- a/litellm/tests/test_custom_callback_input.py
+++ b/litellm/tests/test_custom_callback_input.py
@@ -2,7 +2,7 @@
 ## This test asserts the type of data passed into each method of the custom callback handler
 import sys, os, time, inspect, asyncio, traceback
 from datetime import datetime
-import pytest
+import pytest, uuid
 from pydantic import BaseModel
 
 sys.path.insert(0, os.path.abspath("../.."))
@@ -795,6 +795,50 @@ async def test_async_completion_azure_caching():
     assert len(customHandler_caching.states) == 4  # pre, post, success, success
 
 
+@pytest.mark.asyncio
+async def test_async_completion_azure_caching_streaming():
+    litellm.set_verbose = True
+    customHandler_caching = CompletionCustomHandler()
+    litellm.cache = Cache(
+        type="redis",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
+    litellm.callbacks = [customHandler_caching]
+    unique_time = uuid.uuid4()
+    response1 = await litellm.acompletion(
+        model="azure/chatgpt-v-2",
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
+        caching=True,
+        stream=True,
+    )
+    async for chunk in response1:
+        continue
+    await asyncio.sleep(1)
+    print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
+    response2 = await litellm.acompletion(
+        model="azure/chatgpt-v-2",
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
+        caching=True,
+        stream=True,
+    )
+    async for chunk in response2:
+        continue
+    await asyncio.sleep(1)  # success callbacks are done in parallel
+    print(
+        f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
+    )
+    assert len(customHandler_caching.errors) == 0
+    assert (
+        len(customHandler_caching.states) > 4
+    )  # pre, post, streaming .., success, success
+
+
 @pytest.mark.asyncio
 async def test_async_embedding_azure_caching():
     print("Testing custom callback input - Azure Caching")
diff --git a/litellm/utils.py b/litellm/utils.py
index 0133db50b..a7f8c378d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2328,6 +2328,13 @@ def client(original_function):
                                     model_response_object=ModelResponse(),
                                     stream=kwargs.get("stream", False),
                                 )
+
+                                cached_result = CustomStreamWrapper(
+                                    completion_stream=cached_result,
+                                    model=model,
+                                    custom_llm_provider="cached_response",
+                                    logging_obj=logging_obj,
+                                )
                             elif call_type == CallTypes.embedding.value and isinstance(
                                 cached_result, dict
                             ):
@@ -2624,28 +2631,6 @@ def client(original_function):
                         cached_result, list
                     ):
                         print_verbose(f"Cache Hit!")
-                        call_type = original_function.__name__
-                        if call_type == CallTypes.acompletion.value and isinstance(
-                            cached_result, dict
-                        ):
-                            if kwargs.get("stream", False) == True:
-                                cached_result = convert_to_streaming_response_async(
-                                    response_object=cached_result,
-                                )
-                            else:
-                                cached_result = convert_to_model_response_object(
-                                    response_object=cached_result,
-                                    model_response_object=ModelResponse(),
-                                )
-                        elif call_type == CallTypes.aembedding.value and isinstance(
-                            cached_result, dict
-                        ):
-                            cached_result = convert_to_model_response_object(
-                                response_object=cached_result,
-                                model_response_object=EmbeddingResponse(),
-                                response_type="embedding",
-                            )
-                        # LOG SUCCESS
                         cache_hit = True
                         end_time = datetime.datetime.now()
                         (
@@ -2685,15 +2670,44 @@ def client(original_function):
                             additional_args=None,
                             stream=kwargs.get("stream", False),
                         )
-                        asyncio.create_task(
-                            logging_obj.async_success_handler(
-                                cached_result, start_time, end_time, cache_hit
+                        call_type = original_function.__name__
+                        if call_type == CallTypes.acompletion.value and isinstance(
+                            cached_result, dict
+                        ):
+                            if kwargs.get("stream", False) == True:
+                                cached_result = convert_to_streaming_response_async(
+                                    response_object=cached_result,
+                                )
+                                cached_result = CustomStreamWrapper(
+                                    completion_stream=cached_result,
+                                    model=model,
+                                    custom_llm_provider="cached_response",
+                                    logging_obj=logging_obj,
+                                )
+                            else:
+                                cached_result = convert_to_model_response_object(
+                                    response_object=cached_result,
+                                    model_response_object=ModelResponse(),
+                                )
+                        elif call_type == CallTypes.aembedding.value and isinstance(
+                            cached_result, dict
+                        ):
+                            cached_result = convert_to_model_response_object(
+                                response_object=cached_result,
+                                model_response_object=EmbeddingResponse(),
+                                response_type="embedding",
                             )
-                        )
-                        threading.Thread(
-                            target=logging_obj.success_handler,
-                            args=(cached_result, start_time, end_time, cache_hit),
-                        ).start()
+                        if kwargs.get("stream", False) == False:
+                            # LOG SUCCESS
+                            asyncio.create_task(
+                                logging_obj.async_success_handler(
+                                    cached_result, start_time, end_time, cache_hit
+                                )
+                            )
+                            threading.Thread(
+                                target=logging_obj.success_handler,
+                                args=(cached_result, start_time, end_time, cache_hit),
+                            ).start()
                         return cached_result
                     elif (
                         call_type == CallTypes.aembedding.value
@@ -4296,7 +4310,9 @@ def get_optional_params(
                     parameters=tool["function"].get("parameters", {}),
                 )
                 gtool_func_declarations.append(gtool_func_declaration)
-            optional_params["tools"] = [generative_models.Tool(function_declarations=gtool_func_declarations)]
+            optional_params["tools"] = [
+                generative_models.Tool(function_declarations=gtool_func_declarations)
+            ]
         elif custom_llm_provider == "sagemaker":
             ## check if unsupported param passed in
             supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
@@ -8524,6 +8540,19 @@ class CustomStreamWrapper:
                     ]
             elif self.custom_llm_provider == "text-completion-openai":
                 response_obj = self.handle_openai_text_completion_chunk(chunk)
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj["is_finished"]:
+                    model_response.choices[0].finish_reason = response_obj[
+                        "finish_reason"
+                    ]
+            elif self.custom_llm_provider == "cached_response":
+                response_obj = {
+                    "text": chunk.choices[0].delta.content,
+                    "is_finished": True,
+                    "finish_reason": chunk.choices[0].finish_reason,
+                }
+                completion_obj["content"] = response_obj["text"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
                 if response_obj["is_finished"]:
                     model_response.choices[0].finish_reason = response_obj[
@@ -8732,6 +8761,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "vertex_ai"
                 or self.custom_llm_provider == "sagemaker"
                 or self.custom_llm_provider == "gemini"
+                or self.custom_llm_provider == "cached_response"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 async for chunk in self.completion_stream:
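The patch replays a Redis cache hit through CustomStreamWrapper with custom_llm_provider="cached_response", so a stream=True call still yields chunks when the response comes from the cache instead of returning a plain ModelResponse. A minimal usage sketch (not part of the patch) mirroring test_async_completion_azure_caching_streaming, assuming the same Redis env vars and an "azure/chatgpt-v-2" deployment are configured:

# Illustrative sketch only: exercising the new "cached_response" streaming path.
import asyncio
import os

import litellm
from litellm.caching import Cache

litellm.cache = Cache(
    type="redis",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
)


async def main():
    # First pass streams from Azure and populates the cache; the second pass is
    # served from Redis but is still consumed chunk-by-chunk, because the cached
    # dict is converted back into a stream and wrapped in CustomStreamWrapper.
    for _ in range(2):
        response = await litellm.acompletion(
            model="azure/chatgpt-v-2",
            messages=[{"role": "user", "content": "Hi 👋 - cached streaming demo"}],
            caching=True,
            stream=True,
        )
        async for chunk in response:
            print(chunk.choices[0].delta.content or "", end="")
        print()
        # The streamed response is written to the cache by the async success
        # callback after the stream finishes, so give it a moment (as the test does).
        await asyncio.sleep(1)


asyncio.run(main())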