From df9df7b040f8fffc7c35908e32d2a4950650fbf8 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 11 Jan 2024 16:30:05 +0530 Subject: [PATCH 01/22] fix: n --- litellm/caching.py | 12 +-- litellm/main.py | 4 +- litellm/tests/test_caching.py | 153 +++++++++++----------------------- 3 files changed, 58 insertions(+), 111 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index e1678a109..2c01a17c6 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -81,7 +81,7 @@ class RedisCache(BaseCache): def set_cache(self, key, value, **kwargs): ttl = kwargs.get("ttl", None) - print_verbose(f"Set Redis Cache: key: {key}\nValue {value}") + print_verbose(f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}") try: self.redis_client.set(name=key, value=str(value), ex=ttl) except Exception as e: @@ -171,7 +171,7 @@ class S3Cache(BaseCache): CacheControl=cache_control, ContentType="application/json", ContentLanguage="en", - ContentDisposition=f"inline; filename=\"{key}.json\"" + ContentDisposition=f'inline; filename="{key}.json"', ) else: cache_control = "immutable, max-age=31536000, s-maxage=31536000" @@ -183,7 +183,7 @@ class S3Cache(BaseCache): CacheControl=cache_control, ContentType="application/json", ContentLanguage="en", - ContentDisposition=f"inline; filename=\"{key}.json\"" + ContentDisposition=f'inline; filename="{key}.json"', ) except Exception as e: # NON blocking - notify users S3 is throwing an exception @@ -495,7 +495,6 @@ class Cache: cached_result is not None and isinstance(cached_result, dict) and "timestamp" in cached_result - and max_age is not None ): timestamp = cached_result["timestamp"] current_time = time.time() @@ -504,7 +503,7 @@ class Cache: response_age = current_time - timestamp # Check if the cached response is older than the max-age - if response_age > max_age: + if max_age is not None and response_age > max_age: print_verbose( f"Cached response for key {cache_key} is too old. Max-age: {max_age}s, Age: {response_age}s" ) @@ -565,6 +564,9 @@ class Cache: async def _async_add_cache(self, result, *args, **kwargs): self.add_cache(result, *args, **kwargs) + async def _async_get_cache(self, *args, **kwargs): + return self.get_cache(*args, **kwargs) + def enable_cache( type: Optional[Literal["local", "redis", "s3"]] = "local", diff --git a/litellm/main.py b/litellm/main.py index 2b53c3a5f..cb67774bf 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -267,10 +267,10 @@ async def acompletion( elif asyncio.iscoroutine(init_response): response = await init_response else: - response = init_response + response = init_response # type: ignore else: # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) + response = await loop.run_in_executor(None, func_with_context) # type: ignore # if kwargs.get("stream", False): # return an async generator # return _async_streaming( # response=response, diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 7b8290604..ec99e9c95 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -11,10 +11,10 @@ sys.path.insert( ) # Adds the parent directory to the system path import pytest import litellm -from litellm import embedding, completion +from litellm import embedding, completion, aembedding from litellm.caching import Cache import random -import hashlib +import hashlib, asyncio # litellm.set_verbose=True @@ -261,6 +261,51 @@ def test_embedding_caching_azure(): # test_embedding_caching_azure() +@pytest.mark.asyncio +async def test_embedding_caching_azure_individual_items(): + """ + Tests caching for individual items in an embedding list + + Assert if the same embeddingresponse object is returned for the duplicate item in 2 embedding list calls + + ``` + embedding_1 = ["hey how's it going", "I'm doing well"] + embedding_val_1 = embedding(...) + + embedding_2 = ["hey how's it going", "I'm fine"] + embedding_val_2 = embedding(...) + + assert embedding_val_1[0]["id"] == embedding_val_2[0]["id"] + ``` + """ + litellm.cache = Cache() + common_msg = f"hey how's it going {uuid.uuid4()}" + embedding_1 = [common_msg, "I'm doing well"] + embedding_2 = [common_msg, "I'm fine"] + + embedding_val_1 = await aembedding( + model="azure/azure-embedding-model", input=embedding_1, caching=True + ) + + embedding_val_2 = await aembedding( + model="azure/azure-embedding-model", input=embedding_2, caching=True + ) + if ( + embedding_val_2["data"][0]["embedding"] + != embedding_val_1["data"][0]["embedding"] + ): + print(f"embedding1: {embedding_val_1}") + print(f"embedding2: {embedding_val_2}") + pytest.fail("Error occurred: Embedding caching failed") + if ( + embedding_val_2["data"][1]["embedding"] + == embedding_val_1["data"][1]["embedding"] + ): + print(f"embedding1: {embedding_val_1}") + print(f"embedding2: {embedding_val_2}") + pytest.fail("Error occurred: Embedding caching failed") + + def test_redis_cache_completion(): litellm.set_verbose = False @@ -401,14 +446,14 @@ def test_redis_cache_completion_stream(): """ -test_redis_cache_completion_stream() +# test_redis_cache_completion_stream() def test_redis_cache_acompletion_stream(): import asyncio try: - litellm.set_verbose = True + litellm.set_verbose = False random_word = generate_random_word() messages = [ { @@ -436,7 +481,6 @@ def test_redis_cache_acompletion_stream(): stream=True, ) async for chunk in response1: - print(chunk) response_1_content += chunk.choices[0].delta.content or "" print(response_1_content) @@ -454,7 +498,6 @@ def test_redis_cache_acompletion_stream(): stream=True, ) async for chunk in response2: - print(chunk) response_2_content += chunk.choices[0].delta.content or "" print(response_2_content) @@ -916,101 +959,3 @@ def test_cache_context_managers(): # test_cache_context_managers() - -# test_custom_redis_cache_params() - -# def test_redis_cache_with_ttl(): -# cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) -# sample_model_response_object_str = """{ -# "choices": [ -# { -# "finish_reason": "stop", -# "index": 0, -# "message": { -# "role": "assistant", -# "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic." -# } -# } -# ], -# "created": 1691429984.3852863, -# "model": "claude-instant-1", -# "usage": { -# "prompt_tokens": 18, -# "completion_tokens": 23, -# "total_tokens": 41 -# } -# }""" -# sample_model_response_object = { -# "choices": [ -# { -# "finish_reason": "stop", -# "index": 0, -# "message": { -# "role": "assistant", -# "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic." -# } -# } -# ], -# "created": 1691429984.3852863, -# "model": "claude-instant-1", -# "usage": { -# "prompt_tokens": 18, -# "completion_tokens": 23, -# "total_tokens": 41 -# } -# } -# cache.add_cache(cache_key="test_key", result=sample_model_response_object_str, ttl=1) -# cached_value = cache.get_cache(cache_key="test_key") -# print(f"cached-value: {cached_value}") -# assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content'] -# time.sleep(2) -# assert cache.get_cache(cache_key="test_key") is None - -# # test_redis_cache_with_ttl() - -# def test_in_memory_cache_with_ttl(): -# cache = Cache(type="local") -# sample_model_response_object_str = """{ -# "choices": [ -# { -# "finish_reason": "stop", -# "index": 0, -# "message": { -# "role": "assistant", -# "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic." -# } -# } -# ], -# "created": 1691429984.3852863, -# "model": "claude-instant-1", -# "usage": { -# "prompt_tokens": 18, -# "completion_tokens": 23, -# "total_tokens": 41 -# } -# }""" -# sample_model_response_object = { -# "choices": [ -# { -# "finish_reason": "stop", -# "index": 0, -# "message": { -# "role": "assistant", -# "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic." -# } -# } -# ], -# "created": 1691429984.3852863, -# "model": "claude-instant-1", -# "usage": { -# "prompt_tokens": 18, -# "completion_tokens": 23, -# "total_tokens": 41 -# } -# } -# cache.add_cache(cache_key="test_key", result=sample_model_response_object_str, ttl=1) -# cached_value = cache.get_cache(cache_key="test_key") -# assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content'] -# time.sleep(2) -# assert cache.get_cache(cache_key="test_key") is None -# # test_in_memory_cache_with_ttl() From 2cd5f0fbe92517d310e3f80ef8dfd7ff9b71b408 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 11 Jan 2024 16:51:34 +0530 Subject: [PATCH 02/22] fix(utils.py): support caching individual items in embedding input list https://github.com/BerriAI/litellm/issues/1350 --- litellm/main.py | 2 +- litellm/utils.py | 82 +++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 78 insertions(+), 6 deletions(-) diff --git a/litellm/main.py b/litellm/main.py index cb67774bf..3ec82ed0a 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -234,7 +234,7 @@ async def acompletion( } try: # Use a partial function to pass your keyword arguments - func = partial(completion, **completion_kwargs, **kwargs) + func = partial(completion, **completion_kwargs) # Add the context to the function ctx = contextvars.copy_context() diff --git a/litellm/utils.py b/litellm/utils.py index fcf6e9dea..49bb47420 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2174,6 +2174,7 @@ def client(original_function): result = None logging_obj = kwargs.get("litellm_logging_obj", None) # only set litellm_call_id if its not in kwargs + call_type = original_function.__name__ if "litellm_call_id" not in kwargs: kwargs["litellm_call_id"] = str(uuid.uuid4()) try: @@ -2204,6 +2205,7 @@ def client(original_function): f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}" ) # if caching is false, don't run this + final_embedding_cached_response = None if ( (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True @@ -2220,8 +2222,24 @@ def client(original_function): in litellm.cache.supported_call_types ): print_verbose(f"Checking Cache") - cached_result = litellm.cache.get_cache(*args, **kwargs) - if cached_result != None: + if call_type == CallTypes.aembedding.value and isinstance( + kwargs["input"], list + ): + tasks = [] + embedding_kwargs = copy.deepcopy(kwargs) + for idx, i in enumerate(kwargs["input"]): + embedding_kwargs["input"] = i + tasks.append( + litellm.cache._async_get_cache( + *args, **embedding_kwargs + ) + ) + cached_result = await asyncio.gather(*tasks) + else: + cached_result = litellm.cache.get_cache(*args, **kwargs) + if cached_result is not None and not isinstance( + cached_result, list + ): print_verbose(f"Cache Hit!") call_type = original_function.__name__ if call_type == CallTypes.acompletion.value and isinstance( @@ -2294,6 +2312,30 @@ def client(original_function): args=(cached_result, start_time, end_time, cache_hit), ).start() return cached_result + elif ( + call_type == CallTypes.aembedding.value + and cached_result is not None + and isinstance(cached_result, list) + ): + remaining_list = [] + non_null_list = [] + for idx, cr in enumerate(cached_result): + if cr is None: + remaining_list.append(kwargs["input"][idx]) + else: + non_null_list.append((idx, cr)) + original_kwargs_input = kwargs["input"] + kwargs["input"] = remaining_list + + if len(non_null_list) > 0: + final_embedding_cached_response = EmbeddingResponse( + model=kwargs.get("model"), data=[] + ) + + for val in non_null_list: + idx, cr = val # (idx, cr) tuple + if cr is not None: + final_embedding_cached_response.data[idx] = val # MODEL CALL result = await original_function(*args, **kwargs) end_time = datetime.datetime.now() @@ -2323,9 +2365,23 @@ def client(original_function): if isinstance(result, litellm.ModelResponse) or isinstance( result, litellm.EmbeddingResponse ): - asyncio.create_task( - litellm.cache._async_add_cache(result.json(), *args, **kwargs) - ) + if isinstance(result, EmbeddingResponse) and isinstance( + kwargs["input"], list + ): + embedding_kwargs = copy.deepcopy(kwargs) + for idx, i in enumerate(kwargs["input"]): + embedding_response = result.data[idx] + asyncio.create_task( + litellm.cache._async_add_cache( + embedding_response, *args, **embedding_kwargs + ) + ) + else: + asyncio.create_task( + litellm.cache._async_add_cache( + result.json(), *args, **kwargs + ) + ) else: asyncio.create_task( litellm.cache._async_add_cache(result, *args, **kwargs) @@ -2349,6 +2405,22 @@ def client(original_function): result._response_ms = ( end_time - start_time ).total_seconds() * 1000 # return response latency in ms like openai + elif ( + isinstance(result, EmbeddingResponse) + and final_embedding_cached_response is not None + ): + idx = 0 + final_data_list = [] + for item in final_embedding_cached_response.data: + if item is None: + final_data_list.append(result.data[idx]) + else: + final_data_list.append(item) + idx += 1 + + final_embedding_cached_response.data = final_data_list + return final_embedding_cached_response + return result except Exception as e: traceback_exception = traceback.format_exc() From 252c8415c6795741cbfa8136df90b3fb47c81307 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 11 Jan 2024 16:55:19 +0530 Subject: [PATCH 03/22] fix(main.py): add back **kwargs for acompletion --- litellm/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/main.py b/litellm/main.py index 3ec82ed0a..cb67774bf 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -234,7 +234,7 @@ async def acompletion( } try: # Use a partial function to pass your keyword arguments - func = partial(completion, **completion_kwargs) + func = partial(completion, **completion_kwargs, **kwargs) # Add the context to the function ctx = contextvars.copy_context() From bfa26dd5b34ce8eb1b9865ba07433df471dadadc Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 11 Jan 2024 18:14:22 +0530 Subject: [PATCH 04/22] fix(utils.py): bug fixes --- litellm/tests/test_custom_callback_input.py | 7 +- litellm/utils.py | 90 +++++++++++++++++++-- 2 files changed, 88 insertions(+), 9 deletions(-) diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index 0fb69b645..d9364d11e 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -545,8 +545,9 @@ async def test_async_chat_bedrock_stream(): # asyncio.run(test_async_chat_bedrock_stream()) -# Text Completion - +# Text Completion + + ## Test OpenAI text completion + Async @pytest.mark.asyncio async def test_async_text_completion_openai_stream(): @@ -585,6 +586,7 @@ async def test_async_text_completion_openai_stream(): except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") + # EMBEDDING ## Test OpenAI + Async @pytest.mark.asyncio @@ -758,6 +760,7 @@ async def test_async_embedding_azure_caching(): ) await asyncio.sleep(1) # success callbacks are done in parallel print(customHandler_caching.states) + print(customHandler_caching.errors) assert len(customHandler_caching.errors) == 0 assert len(customHandler_caching.states) == 4 # pre, post, success, success diff --git a/litellm/utils.py b/litellm/utils.py index 49bb47420..e4c67c0e8 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2026,12 +2026,17 @@ def client(original_function): ) # if caching is false or cache["no-cache"]==True, don't run this if ( - (kwargs.get("caching", None) is None and litellm.cache is not None) - or kwargs.get("caching", False) == True - or ( - kwargs.get("cache", None) is not None - and kwargs.get("cache", {}).get("no-cache", False) != True + ( + (kwargs.get("caching", None) is None and litellm.cache is not None) + or kwargs.get("caching", False) == True + or ( + kwargs.get("cache", None) is not None + and kwargs.get("cache", {}).get("no-cache", False) != True + ) ) + and kwargs.get("aembedding", False) != True + and kwargs.get("acompletion", False) != True + and kwargs.get("aimg_generation", False) != True ): # allow users to control returning cached responses from the completion function # checking cache print_verbose(f"INSIDE CHECKING CACHE") @@ -2329,13 +2334,78 @@ def client(original_function): if len(non_null_list) > 0: final_embedding_cached_response = EmbeddingResponse( - model=kwargs.get("model"), data=[] + model=kwargs.get("model"), + data=[None] * len(original_kwargs_input), ) for val in non_null_list: idx, cr = val # (idx, cr) tuple if cr is not None: final_embedding_cached_response.data[idx] = val + + if len(remaining_list) == 0: + # LOG SUCCESS + cache_hit = True + end_time = datetime.datetime.now() + ( + model, + custom_llm_provider, + dynamic_api_key, + api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=kwargs.get( + "custom_llm_provider", None + ), + api_base=kwargs.get("api_base", None), + api_key=kwargs.get("api_key", None), + ) + print_verbose( + f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" + ) + logging_obj.update_environment_variables( + model=model, + user=kwargs.get("user", None), + optional_params={}, + litellm_params={ + "logger_fn": kwargs.get("logger_fn", None), + "acompletion": True, + "metadata": kwargs.get("metadata", {}), + "model_info": kwargs.get("model_info", {}), + "proxy_server_request": kwargs.get( + "proxy_server_request", None + ), + "preset_cache_key": kwargs.get( + "preset_cache_key", None + ), + "stream_response": kwargs.get( + "stream_response", {} + ), + }, + input=kwargs.get("messages", ""), + api_key=kwargs.get("api_key", None), + original_response=str(final_embedding_cached_response), + additional_args=None, + stream=kwargs.get("stream", False), + ) + asyncio.create_task( + logging_obj.async_success_handler( + final_embedding_cached_response, + start_time, + end_time, + cache_hit, + ) + ) + threading.Thread( + target=logging_obj.success_handler, + args=( + final_embedding_cached_response, + start_time, + end_time, + cache_hit, + ), + ).start() + return final_embedding_cached_response # MODEL CALL result = await original_function(*args, **kwargs) end_time = datetime.datetime.now() @@ -2371,6 +2441,7 @@ def client(original_function): embedding_kwargs = copy.deepcopy(kwargs) for idx, i in enumerate(kwargs["input"]): embedding_response = result.data[idx] + embedding_kwargs["input"] = i asyncio.create_task( litellm.cache._async_add_cache( embedding_response, *args, **embedding_kwargs @@ -5971,7 +6042,12 @@ def exception_type( message=f"BedrockException - {original_exception.message}", llm_provider="bedrock", model=model, - response=original_exception.response, + response=httpx.Response( + status_code=500, + request=httpx.Request( + method="POST", url="https://api.openai.com/v1/" + ), + ), ) elif original_exception.status_code == 401: exception_mapping_worked = True From 1378190dbf0ba95bbec71f7bdc67a40048be85a5 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 11 Jan 2024 18:30:10 +0530 Subject: [PATCH 05/22] fix(main.py): init custom llm provider earlier --- litellm/main.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/litellm/main.py b/litellm/main.py index cb67774bf..19e192b40 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -232,6 +232,9 @@ async def acompletion( "model_list": model_list, "acompletion": True, # assuming this is a required parameter } + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=completion_kwargs.get("base_url", None) + ) try: # Use a partial function to pass your keyword arguments func = partial(completion, **completion_kwargs, **kwargs) @@ -240,10 +243,6 @@ async def acompletion( ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) - _, custom_llm_provider, _, _ = get_llm_provider( - model=model, api_base=completion_kwargs.get("base_url", None) - ) - if ( custom_llm_provider == "openai" or custom_llm_provider == "azure" From 0440168915f858aa17753f82540b1d31684619ee Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 11 Jan 2024 18:44:58 +0530 Subject: [PATCH 06/22] test(test_custom_callback_input.py): make test more verbsoe n --- litellm/tests/test_custom_callback_input.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index d9364d11e..537e38a28 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -706,6 +706,7 @@ async def test_async_embedding_bedrock(): ## Test Azure - completion, embedding @pytest.mark.asyncio async def test_async_completion_azure_caching(): + litellm.set_verbose = True customHandler_caching = CompletionCustomHandler() litellm.cache = Cache( type="redis", From f8d17206904ae86e6eec7fdbfa53cc03bcfc4ab5 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 11 Jan 2024 19:02:17 +0530 Subject: [PATCH 07/22] fix(utils.py): bug fixes --- litellm/tests/test_caching.py | 1 + litellm/tests/test_custom_callback_input.py | 2 +- litellm/utils.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index ec99e9c95..11d4fda15 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -290,6 +290,7 @@ async def test_embedding_caching_azure_individual_items(): embedding_val_2 = await aembedding( model="azure/azure-embedding-model", input=embedding_2, caching=True ) + print(f"embedding_val_2: {embedding_val_2}") if ( embedding_val_2["data"][0]["embedding"] != embedding_val_1["data"][0]["embedding"] diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index 537e38a28..8b41dfabb 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -673,7 +673,7 @@ async def test_async_embedding_bedrock(): response = await litellm.aembedding( model="bedrock/cohere.embed-multilingual-v3", input=["good morning from litellm"], - aws_region_name="os.environ/AWS_REGION_NAME_2", + aws_region_name="us-east-1", ) await asyncio.sleep(1) print(f"customHandler_success.errors: {customHandler_success.errors}") diff --git a/litellm/utils.py b/litellm/utils.py index e4c67c0e8..45889835e 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2341,7 +2341,7 @@ def client(original_function): for val in non_null_list: idx, cr = val # (idx, cr) tuple if cr is not None: - final_embedding_cached_response.data[idx] = val + final_embedding_cached_response.data[idx] = cr if len(remaining_list) == 0: # LOG SUCCESS @@ -2485,9 +2485,9 @@ def client(original_function): for item in final_embedding_cached_response.data: if item is None: final_data_list.append(result.data[idx]) + idx += 1 else: final_data_list.append(item) - idx += 1 final_embedding_cached_response.data = final_data_list return final_embedding_cached_response From a1af7688cee504d0c08c9b9f954b5d7db1cdf079 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 11 Jan 2024 19:30:28 +0530 Subject: [PATCH 08/22] fix(utils.py): use preset cache key for async calls as well --- litellm/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/litellm/utils.py b/litellm/utils.py index 45889835e..f6eab167d 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2241,6 +2241,10 @@ def client(original_function): ) cached_result = await asyncio.gather(*tasks) else: + preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) + kwargs[ + "preset_cache_key" + ] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key cached_result = litellm.cache.get_cache(*args, **kwargs) if cached_result is not None and not isinstance( cached_result, list From 007870390d497ea9022ee520e375602ceb7f8dbb Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 12 Jan 2024 21:46:41 +0530 Subject: [PATCH 09/22] fix: support async redis caching --- litellm/_redis.py | 22 ++- litellm/caching.py | 243 +++++++++++++++++++++++++--------- litellm/proxy/proxy_cli.py | 58 ++++---- litellm/proxy/proxy_server.py | 18 ++- litellm/tests/test_caching.py | 106 ++++++++++++--- litellm/utils.py | 32 +++-- 6 files changed, 357 insertions(+), 122 deletions(-) diff --git a/litellm/_redis.py b/litellm/_redis.py index bee73f134..36f4ef870 100644 --- a/litellm/_redis.py +++ b/litellm/_redis.py @@ -11,6 +11,7 @@ import os import inspect import redis, litellm +import redis.asyncio as async_redis from typing import List, Optional @@ -67,7 +68,10 @@ def get_redis_url_from_environment(): ) -def get_redis_client(**env_overrides): +def _get_redis_client_logic(**env_overrides): + """ + Common functionality across sync + async redis client implementations + """ ### check if "os.environ/" passed in for k, v in env_overrides.items(): if isinstance(v, str) and v.startswith("os.environ/"): @@ -85,9 +89,21 @@ def get_redis_client(**env_overrides): redis_kwargs.pop("port", None) redis_kwargs.pop("db", None) redis_kwargs.pop("password", None) - - return redis.Redis.from_url(**redis_kwargs) elif "host" not in redis_kwargs or redis_kwargs["host"] is None: raise ValueError("Either 'host' or 'url' must be specified for redis.") litellm.print_verbose(f"redis_kwargs: {redis_kwargs}") + return redis_kwargs + + +def get_redis_client(**env_overrides): + redis_kwargs = _get_redis_client_logic(**env_overrides) + if "url" in redis_kwargs and redis_kwargs["url"] is not None: + return redis.Redis.from_url(**redis_kwargs) return redis.Redis(**redis_kwargs) + + +def get_redis_async_client(**env_overrides): + redis_kwargs = _get_redis_client_logic(**env_overrides) + if "url" in redis_kwargs and redis_kwargs["url"] is not None: + return async_redis.Redis.from_url(**redis_kwargs) + return async_redis.Redis(socket_timeout=5, **redis_kwargs) diff --git a/litellm/caching.py b/litellm/caching.py index 2c01a17c6..b89220e8d 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -26,9 +26,18 @@ class BaseCache: def set_cache(self, key, value, **kwargs): raise NotImplementedError + async def async_set_cache(self, key, value, **kwargs): + raise NotImplementedError + def get_cache(self, key, **kwargs): raise NotImplementedError + async def async_get_cache(self, key, **kwargs): + raise NotImplementedError + + async def disconnect(self): + raise NotImplementedError + class InMemoryCache(BaseCache): def __init__(self): @@ -41,6 +50,9 @@ class InMemoryCache(BaseCache): if "ttl" in kwargs: self.ttl_dict[key] = time.time() + kwargs["ttl"] + async def async_set_cache(self, key, value, **kwargs): + self.set_cache(key=key, value=value, **kwargs) + def get_cache(self, key, **kwargs): if key in self.cache_dict: if key in self.ttl_dict: @@ -55,16 +67,21 @@ class InMemoryCache(BaseCache): return cached_response return None + async def async_get_cache(self, key, **kwargs): + return self.get_cache(key=key, **kwargs) + def flush_cache(self): self.cache_dict.clear() self.ttl_dict.clear() + async def disconnect(self): + pass + class RedisCache(BaseCache): - def __init__(self, host=None, port=None, password=None, **kwargs): - import redis + # if users don't provider one, use the default litellm cache - # if users don't provider one, use the default litellm cache + def __init__(self, host=None, port=None, password=None, **kwargs): from ._redis import get_redis_client redis_kwargs = {} @@ -76,8 +93,13 @@ class RedisCache(BaseCache): redis_kwargs["password"] = password redis_kwargs.update(kwargs) - self.redis_client = get_redis_client(**redis_kwargs) + self.redis_kwargs = redis_kwargs + + def init_async_client(self): + from ._redis import get_redis_async_client + + return get_redis_async_client(**self.redis_kwargs) def set_cache(self, key, value, **kwargs): ttl = kwargs.get("ttl", None) @@ -88,6 +110,34 @@ class RedisCache(BaseCache): # NON blocking - notify users Redis is throwing an exception logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e) + async def async_set_cache(self, key, value, **kwargs): + async with self.init_async_client() as redis_client: + ttl = kwargs.get("ttl", None) + print_verbose( + f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}" + ) + try: + await redis_client.set(name=key, value=str(value), ex=ttl) + except Exception as e: + # NON blocking - notify users Redis is throwing an exception + logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e) + + def _get_cache_logic(self, cached_response: Any): + """ + Common 'get_cache_logic' across sync + async redis client implementations + """ + if cached_response is None: + return cached_response + # cached_response is in `b{} convert it to ModelResponse + cached_response = cached_response.decode("utf-8") # Convert bytes to string + try: + cached_response = json.loads( + cached_response + ) # Convert string to dictionary + except: + cached_response = ast.literal_eval(cached_response) + return cached_response + def get_cache(self, key, **kwargs): try: print_verbose(f"Get Redis Cache: key: {key}") @@ -95,26 +145,33 @@ class RedisCache(BaseCache): print_verbose( f"Got Redis Cache: key: {key}, cached_response {cached_response}" ) - if cached_response != None: - # cached_response is in `b{} convert it to ModelResponse - cached_response = cached_response.decode( - "utf-8" - ) # Convert bytes to string - try: - cached_response = json.loads( - cached_response - ) # Convert string to dictionary - except: - cached_response = ast.literal_eval(cached_response) - return cached_response + return self._get_cache_logic(cached_response=cached_response) except Exception as e: # NON blocking - notify users Redis is throwing an exception traceback.print_exc() logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e) + async def async_get_cache(self, key, **kwargs): + async with self.init_async_client() as redis_client: + try: + print_verbose(f"Get Redis Cache: key: {key}") + cached_response = await redis_client.get(key) + print_verbose( + f"Got Async Redis Cache: key: {key}, cached_response {cached_response}" + ) + response = self._get_cache_logic(cached_response=cached_response) + return response + except Exception as e: + # NON blocking - notify users Redis is throwing an exception + traceback.print_exc() + logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e) + def flush_cache(self): self.redis_client.flushall() + async def disconnect(self): + pass + class S3Cache(BaseCache): def __init__( @@ -189,6 +246,9 @@ class S3Cache(BaseCache): # NON blocking - notify users S3 is throwing an exception print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}") + async def async_set_cache(self, key, value, **kwargs): + self.set_cache(key=key, value=value, **kwargs) + def get_cache(self, key, **kwargs): import boto3, botocore @@ -229,6 +289,9 @@ class S3Cache(BaseCache): traceback.print_exc() print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}") + async def async_get_cache(self, key, **kwargs): + return self.get_cache(key=key, **kwargs) + def flush_cache(self): pass @@ -468,6 +531,45 @@ class Cache: } time.sleep(0.02) + def _get_cache_logic( + self, + cached_result: Optional[Any], + max_age: Optional[float], + ): + """ + Common get cache logic across sync + async implementations + """ + # Check if a timestamp was stored with the cached response + if ( + cached_result is not None + and isinstance(cached_result, dict) + and "timestamp" in cached_result + ): + timestamp = cached_result["timestamp"] + current_time = time.time() + + # Calculate age of the cached response + response_age = current_time - timestamp + + # Check if the cached response is older than the max-age + if max_age is not None and response_age > max_age: + return None # Cached response is too old + + # If the response is fresh, or there's no max-age requirement, return the cached response + # cached_response is in `b{} convert it to ModelResponse + cached_response = cached_result.get("response") + try: + if isinstance(cached_response, dict): + pass + else: + cached_response = json.loads( + cached_response # type: ignore + ) # Convert string to dictionary + except: + cached_response = ast.literal_eval(cached_response) # type: ignore + return cached_response + return cached_result + def get_cache(self, *args, **kwargs): """ Retrieves the cached result for the given arguments. @@ -490,53 +592,40 @@ class Cache: "s-max-age", cache_control_args.get("s-maxage", float("inf")) ) cached_result = self.cache.get_cache(cache_key) - # Check if a timestamp was stored with the cached response - if ( - cached_result is not None - and isinstance(cached_result, dict) - and "timestamp" in cached_result - ): - timestamp = cached_result["timestamp"] - current_time = time.time() - - # Calculate age of the cached response - response_age = current_time - timestamp - - # Check if the cached response is older than the max-age - if max_age is not None and response_age > max_age: - print_verbose( - f"Cached response for key {cache_key} is too old. Max-age: {max_age}s, Age: {response_age}s" - ) - return None # Cached response is too old - - # If the response is fresh, or there's no max-age requirement, return the cached response - # cached_response is in `b{} convert it to ModelResponse - cached_response = cached_result.get("response") - try: - if isinstance(cached_response, dict): - pass - else: - cached_response = json.loads( - cached_response - ) # Convert string to dictionary - except: - cached_response = ast.literal_eval(cached_response) - return cached_response - return cached_result + return self._get_cache_logic( + cached_result=cached_result, max_age=max_age + ) except Exception as e: print_verbose(f"An exception occurred: {traceback.format_exc()}") return None - def add_cache(self, result, *args, **kwargs): + async def async_get_cache(self, *args, **kwargs): """ - Adds a result to the cache. + Async get cache implementation. - Args: - *args: args to litellm.completion() or embedding() - **kwargs: kwargs to litellm.completion() or embedding() + Used for embedding calls in async wrapper + """ + try: # never block execution + if "cache_key" in kwargs: + cache_key = kwargs["cache_key"] + else: + cache_key = self.get_cache_key(*args, **kwargs) + if cache_key is not None: + cache_control_args = kwargs.get("cache", {}) + max_age = cache_control_args.get( + "s-max-age", cache_control_args.get("s-maxage", float("inf")) + ) + cached_result = await self.cache.async_get_cache(cache_key) + return self._get_cache_logic( + cached_result=cached_result, max_age=max_age + ) + except Exception as e: + print_verbose(f"An exception occurred: {traceback.format_exc()}") + return None - Returns: - None + def _add_cache_logic(self, result, *args, **kwargs): + """ + Common implementation across sync + async add_cache functions """ try: if "cache_key" in kwargs: @@ -555,17 +644,49 @@ class Cache: if k == "ttl": kwargs["ttl"] = v cached_data = {"timestamp": time.time(), "response": result} - self.cache.set_cache(cache_key, cached_data, **kwargs) + return cache_key, cached_data + else: + raise Exception("cache key is None") + except Exception as e: + raise e + + def add_cache(self, result, *args, **kwargs): + """ + Adds a result to the cache. + + Args: + *args: args to litellm.completion() or embedding() + **kwargs: kwargs to litellm.completion() or embedding() + + Returns: + None + """ + try: + cache_key, cached_data = self._add_cache_logic( + result=result, *args, **kwargs + ) + self.cache.set_cache(cache_key, cached_data, **kwargs) except Exception as e: print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}") traceback.print_exc() pass - async def _async_add_cache(self, result, *args, **kwargs): - self.add_cache(result, *args, **kwargs) + async def async_add_cache(self, result, *args, **kwargs): + """ + Async implementation of add_cache + """ + try: + cache_key, cached_data = self._add_cache_logic( + result=result, *args, **kwargs + ) + await self.cache.async_set_cache(cache_key, cached_data, **kwargs) + except Exception as e: + print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}") + traceback.print_exc() - async def _async_get_cache(self, *args, **kwargs): - return self.get_cache(*args, **kwargs) + async def disconnect(self): + if hasattr(self.cache, "disconnect"): + await self.cache.disconnect() def enable_cache( diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index c06ba7d32..9a1a01d66 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -346,7 +346,7 @@ def run_server( import gunicorn.app.base except: raise ImportError( - "Uvicorn, gunicorn needs to be imported. Run - `pip 'litellm[proxy]'`" + "uvicorn, gunicorn needs to be imported. Run - `pip install 'litellm[proxy]'`" ) if config is not None: @@ -427,36 +427,40 @@ def run_server( f"\033[1;34mSee all Router/Swagger docs on http://0.0.0.0:{port} \033[0m\n" ) # noqa - # Gunicorn Application Class - class StandaloneApplication(gunicorn.app.base.BaseApplication): - def __init__(self, app, options=None): - self.options = options or {} # gunicorn options - self.application = app # FastAPI app - super().__init__() + uvicorn.run( + "litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers + ) - def load_config(self): - # note: This Loads the gunicorn config - has nothing to do with LiteLLM Proxy config - config = { - key: value - for key, value in self.options.items() - if key in self.cfg.settings and value is not None - } - for key, value in config.items(): - self.cfg.set(key.lower(), value) + # # Gunicorn Application Class + # class StandaloneApplication(gunicorn.app.base.BaseApplication): + # def __init__(self, app, options=None): + # self.options = options or {} # gunicorn options + # self.application = app # FastAPI app + # super().__init__() - def load(self): - # gunicorn app function - return self.application + # def load_config(self): + # # note: This Loads the gunicorn config - has nothing to do with LiteLLM Proxy config + # config = { + # key: value + # for key, value in self.options.items() + # if key in self.cfg.settings and value is not None + # } + # for key, value in config.items(): + # self.cfg.set(key.lower(), value) - gunicorn_options = { - "bind": f"{host}:{port}", - "workers": num_workers, # default is 1 - "worker_class": "uvicorn.workers.UvicornWorker", - "preload": True, # Add the preload flag - } - from litellm.proxy.proxy_server import app + # def load(self): + # # gunicorn app function + # return self.application - StandaloneApplication(app=app, options=gunicorn_options).run() # Run gunicorn + # gunicorn_options = { + # "bind": f"{host}:{port}", + # "workers": num_workers, # default is 1 + # "worker_class": "uvicorn.workers.UvicornWorker", + # "preload": True, # Add the preload flag + # } + # from litellm.proxy.proxy_server import app + + # StandaloneApplication(app=app, options=gunicorn_options).run() # Run gunicorn if __name__ == "__main__": diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e74314193..8fd62cde2 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -7,6 +7,20 @@ import secrets, subprocess import hashlib, uuid import warnings import importlib +import warnings + + +def showwarning(message, category, filename, lineno, file=None, line=None): + traceback_info = f"{filename}:{lineno}: {category.__name__}: {message}\n" + if file is not None: + file.write(traceback_info) + + +warnings.showwarning = showwarning +warnings.filterwarnings("default", category=UserWarning) + +# Your client code here + messages: list = [] sys.path.insert( @@ -2510,10 +2524,12 @@ async def get_routes(): @router.on_event("shutdown") async def shutdown_event(): global prisma_client, master_key, user_custom_auth - if prisma_client: + if prisma_client is not None: verbose_proxy_logger.debug("Disconnecting from Prisma") await prisma_client.disconnect() + if litellm.cache is not None: + await litellm.cache.disconnect() ## RESET CUSTOM VARIABLES ## cleanup_router_config_variables() diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 11d4fda15..3250a2621 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -266,8 +266,9 @@ async def test_embedding_caching_azure_individual_items(): """ Tests caching for individual items in an embedding list - Assert if the same embeddingresponse object is returned for the duplicate item in 2 embedding list calls - + - Cache an item + - call aembedding(..) with the item + 1 unique item + - compare to a 2nd aembedding (...) with 2 unique items ``` embedding_1 = ["hey how's it going", "I'm doing well"] embedding_val_1 = embedding(...) @@ -280,31 +281,98 @@ async def test_embedding_caching_azure_individual_items(): """ litellm.cache = Cache() common_msg = f"hey how's it going {uuid.uuid4()}" - embedding_1 = [common_msg, "I'm doing well"] - embedding_2 = [common_msg, "I'm fine"] + common_msg_2 = f"hey how's it going {uuid.uuid4()}" + embedding_2 = [ + common_msg, + f"I'm fine {uuid.uuid4()}", + common_msg, + common_msg, + common_msg, + ] * 20 + embedding_2 = [ + common_msg, + f"I'm fine {uuid.uuid4()}", + common_msg, + common_msg, + common_msg, + ] * 20 + embedding_3 = [ + common_msg_2, + common_msg_2, + common_msg_2, + common_msg_2, + f"I'm fine {uuid.uuid4()}", + ] * 20 # make sure azure doesn't return cached 'i'm fine' responses embedding_val_1 = await aembedding( model="azure/azure-embedding-model", input=embedding_1, caching=True ) + second_response_start_time = time.time() embedding_val_2 = await aembedding( model="azure/azure-embedding-model", input=embedding_2, caching=True ) - print(f"embedding_val_2: {embedding_val_2}") - if ( - embedding_val_2["data"][0]["embedding"] - != embedding_val_1["data"][0]["embedding"] - ): - print(f"embedding1: {embedding_val_1}") - print(f"embedding2: {embedding_val_2}") - pytest.fail("Error occurred: Embedding caching failed") - if ( - embedding_val_2["data"][1]["embedding"] - == embedding_val_1["data"][1]["embedding"] - ): - print(f"embedding1: {embedding_val_1}") - print(f"embedding2: {embedding_val_2}") - pytest.fail("Error occurred: Embedding caching failed") + if embedding_val_2 is not None: + second_response_end_time = time.time() + second_response_time = second_response_end_time - second_response_start_time + + third_response_start_time = time.time() + embedding_val_3 = await aembedding( + model="azure/azure-embedding-model", input=embedding_3, cache={"no-cache": True} + ) + if embedding_val_3 is not None: + third_response_end_time = time.time() + third_response_time = third_response_end_time - third_response_start_time + + print(f"second_response_time: {second_response_time}") + print(f"third_response_time: {third_response_time}") + + assert ( + second_response_time < third_response_time - 0.5 + ) # make sure it's actually faster + raise Exception(f"it works {second_response_time} < {third_response_time}") + + +@pytest.mark.asyncio +async def test_redis_cache_basic(): + """ + Init redis client + - write to client + - read from client + """ + litellm.set_verbose = False + + random_number = random.randint( + 1, 100000 + ) # add a random number to ensure it's always adding / reading from cache + messages = [ + {"role": "user", "content": f"write a one sentence poem about: {random_number}"} + ] + litellm.cache = Cache( + type="redis", + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) + response1 = completion( + model="gpt-3.5-turbo", + messages=messages, + ) + + cache_key = litellm.cache.get_cache_key( + model="gpt-3.5-turbo", + messages=messages, + ) + print(f"cache_key: {cache_key}") + litellm.cache.add_cache(result=response1, cache_key=cache_key) + print(f"cache key pre async get: {cache_key}") + stored_val = await litellm.cache.async_get_cache( + model="gpt-3.5-turbo", + messages=messages, + ) + print(f"stored_val: {stored_val}") + assert stored_val["id"] == response1.id + raise Exception("it worked!") def test_redis_cache_completion(): diff --git a/litellm/utils.py b/litellm/utils.py index 8596b5d16..84a81649a 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2214,8 +2214,13 @@ def client(original_function): ) # if caching is false, don't run this final_embedding_cached_response = None + if ( - (kwargs.get("caching", None) is None and litellm.cache is not None) + ( + kwargs.get("caching", None) is None + and kwargs.get("cache", None) is None + and litellm.cache is not None + ) or kwargs.get("caching", False) == True or ( kwargs.get("cache", None) is not None @@ -2234,12 +2239,13 @@ def client(original_function): kwargs["input"], list ): tasks = [] - embedding_kwargs = copy.deepcopy(kwargs) for idx, i in enumerate(kwargs["input"]): - embedding_kwargs["input"] = i + preset_cache_key = litellm.cache.get_cache_key( + *args, **{**kwargs, "input": i} + ) tasks.append( - litellm.cache._async_get_cache( - *args, **embedding_kwargs + litellm.cache.async_get_cache( + cache_key=preset_cache_key ) ) cached_result = await asyncio.gather(*tasks) @@ -2445,24 +2451,28 @@ def client(original_function): if isinstance(result, EmbeddingResponse) and isinstance( kwargs["input"], list ): - embedding_kwargs = copy.deepcopy(kwargs) for idx, i in enumerate(kwargs["input"]): + preset_cache_key = litellm.cache.get_cache_key( + *args, **{**kwargs, "input": i} + ) embedding_response = result.data[idx] - embedding_kwargs["input"] = i asyncio.create_task( - litellm.cache._async_add_cache( - embedding_response, *args, **embedding_kwargs + litellm.cache.async_add_cache( + embedding_response, + *args, + cache_key=preset_cache_key, ) ) + # pass else: asyncio.create_task( - litellm.cache._async_add_cache( + litellm.cache.async_add_cache( result.json(), *args, **kwargs ) ) else: asyncio.create_task( - litellm.cache._async_add_cache(result, *args, **kwargs) + litellm.cache.async_add_cache(result, *args, **kwargs) ) # LOG SUCCESS - handle streaming success logging in the _next_ object print_verbose( From 01df37d8cfdde067c8de7749d4a2e0ae7a89e059 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 13 Jan 2024 11:50:50 +0530 Subject: [PATCH 10/22] fix(caching.py): use bulk writes and blockconnectionpooling for reads from Redis --- litellm/_redis.py | 14 ++++++- litellm/caching.py | 95 +++++++++++++++++++++++++++++++++++++++++++--- litellm/utils.py | 19 ++++------ 3 files changed, 109 insertions(+), 19 deletions(-) diff --git a/litellm/_redis.py b/litellm/_redis.py index 36f4ef870..4484926d4 100644 --- a/litellm/_redis.py +++ b/litellm/_redis.py @@ -106,4 +106,16 @@ def get_redis_async_client(**env_overrides): redis_kwargs = _get_redis_client_logic(**env_overrides) if "url" in redis_kwargs and redis_kwargs["url"] is not None: return async_redis.Redis.from_url(**redis_kwargs) - return async_redis.Redis(socket_timeout=5, **redis_kwargs) + return async_redis.Redis( + socket_timeout=5, + **redis_kwargs, + ) + + +def get_redis_connection_pool(**env_overrides): + redis_kwargs = _get_redis_client_logic(**env_overrides) + if "url" in redis_kwargs and redis_kwargs["url"] is not None: + return async_redis.BlockingConnectionPool.from_url( + timeout=5, url=redis_kwargs["url"] + ) + return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs) diff --git a/litellm/caching.py b/litellm/caching.py index b89220e8d..de3b02297 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -8,7 +8,7 @@ # Thank you users! We ❤️ you! - Krrish & Ishaan import litellm -import time, logging +import time, logging, asyncio import json, traceback, ast, hashlib from typing import Optional, Literal, List, Union, Any from openai._models import BaseModel as OpenAIObject @@ -82,7 +82,7 @@ class RedisCache(BaseCache): # if users don't provider one, use the default litellm cache def __init__(self, host=None, port=None, password=None, **kwargs): - from ._redis import get_redis_client + from ._redis import get_redis_client, get_redis_connection_pool redis_kwargs = {} if host is not None: @@ -95,11 +95,20 @@ class RedisCache(BaseCache): redis_kwargs.update(kwargs) self.redis_client = get_redis_client(**redis_kwargs) self.redis_kwargs = redis_kwargs + self.async_redis_conn_pool = get_redis_connection_pool() + print_verbose( + f"Number of available connections init: {self.async_redis_conn_pool.pool.qsize()}" + ) def init_async_client(self): from ._redis import get_redis_async_client - return get_redis_async_client(**self.redis_kwargs) + print_verbose( + f"Number of available connections client_init: {self.async_redis_conn_pool.pool.qsize()}" + ) + return get_redis_async_client( + connection_pool=self.async_redis_conn_pool, **self.redis_kwargs + ) def set_cache(self, key, value, **kwargs): ttl = kwargs.get("ttl", None) @@ -111,16 +120,52 @@ class RedisCache(BaseCache): logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e) async def async_set_cache(self, key, value, **kwargs): - async with self.init_async_client() as redis_client: + _redis_client = self.init_async_client() + async with _redis_client as redis_client: ttl = kwargs.get("ttl", None) print_verbose( f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}" ) try: - await redis_client.set(name=key, value=str(value), ex=ttl) + await redis_client.set(name=key, value=json.dumps(value), ex=ttl) except Exception as e: # NON blocking - notify users Redis is throwing an exception logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e) + print_verbose( + f"Number of available connections set_cache complete: {self.async_redis_conn_pool.pool.qsize()}" + ) + + async def async_set_cache_pipeline(self, cache_list, ttl=None): + """ + Use Redis Pipelines for bulk write operations + """ + _redis_client = self.init_async_client() + try: + async with _redis_client as redis_client: + async with redis_client.pipeline(transaction=True) as pipe: + # Iterate through each key-value pair in the cache_list and set them in the pipeline. + for cache_key, cache_value in cache_list: + print_verbose( + f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}" + ) + # Set the value with a TTL if it's provided. + if ttl is not None: + pipe.setex(cache_key, ttl, json.dumps(cache_value)) + else: + pipe.set(cache_key, json.dumps(cache_value)) + # Execute the pipeline and return the results. + results = await pipe.execute() + print_verbose( + f"Number of available connections set_cache complete: {self.async_redis_conn_pool.pool.qsize()}" + ) + + print_verbose(f"pipeline results: {results}") + # Optionally, you could process 'results' to make sure that all set operations were successful. + return results + except Exception as e: + print_verbose(f"Error occurred in pipeline write - {str(e)}") + # NON blocking - notify users Redis is throwing an exception + logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e) def _get_cache_logic(self, cached_response: Any): """ @@ -152,7 +197,8 @@ class RedisCache(BaseCache): logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e) async def async_get_cache(self, key, **kwargs): - async with self.init_async_client() as redis_client: + _redis_client = self.init_async_client() + async with _redis_client as redis_client: try: print_verbose(f"Get Redis Cache: key: {key}") cached_response = await redis_client.get(key) @@ -166,6 +212,10 @@ class RedisCache(BaseCache): traceback.print_exc() logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e) + print_verbose( + f"Number of available connections get_cache complete: {self.async_redis_conn_pool.pool.qsize()}" + ) + def flush_cache(self): self.redis_client.flushall() @@ -684,6 +734,39 @@ class Cache: print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}") traceback.print_exc() + async def async_add_cache_pipeline(self, result, *args, **kwargs): + """ + Async implementation of add_cache for Embedding calls + + Does a bulk write, to prevent using too many clients + """ + try: + cache_list = [] + for idx, i in enumerate(kwargs["input"]): + preset_cache_key = litellm.cache.get_cache_key( + *args, **{**kwargs, "input": i} + ) + embedding_response = result.data[idx] + cache_key, cached_data = self._add_cache_logic( + result=embedding_response, + cache_key=preset_cache_key, + *args, + **kwargs, + ) + cache_list.append((cache_key, cached_data)) + if hasattr(self.cache, "async_set_cache_pipeline"): + await self.cache.async_set_cache_pipeline(cache_list=cache_list) + else: + tasks = [] + for val in cache_list: + tasks.append( + self.cache.async_set_cache(cache_key, cached_data, **kwargs) + ) + await asyncio.gather(*tasks) + except Exception as e: + print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}") + traceback.print_exc() + async def disconnect(self): if hasattr(self.cache, "disconnect"): await self.cache.disconnect() diff --git a/litellm/utils.py b/litellm/utils.py index 84a81649a..6059cdf85 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2346,6 +2346,9 @@ def client(original_function): kwargs["input"] = remaining_list if len(non_null_list) > 0: + print_verbose( + f"EMBEDDING CACHE HIT! - {len(non_null_list)}" + ) final_embedding_cached_response = EmbeddingResponse( model=kwargs.get("model"), data=[None] * len(original_kwargs_input), @@ -2451,19 +2454,11 @@ def client(original_function): if isinstance(result, EmbeddingResponse) and isinstance( kwargs["input"], list ): - for idx, i in enumerate(kwargs["input"]): - preset_cache_key = litellm.cache.get_cache_key( - *args, **{**kwargs, "input": i} + asyncio.create_task( + litellm.cache.async_add_cache_pipeline( + result, *args, **kwargs ) - embedding_response = result.data[idx] - asyncio.create_task( - litellm.cache.async_add_cache( - embedding_response, - *args, - cache_key=preset_cache_key, - ) - ) - # pass + ) else: asyncio.create_task( litellm.cache.async_add_cache( From c43a141889fb1fb198b5d4fd3877b692af332a3e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 13 Jan 2024 14:11:05 +0530 Subject: [PATCH 11/22] fix(caching.py): remove print verbose statement --- litellm/caching.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index de3b02297..5601fab2f 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -96,16 +96,10 @@ class RedisCache(BaseCache): self.redis_client = get_redis_client(**redis_kwargs) self.redis_kwargs = redis_kwargs self.async_redis_conn_pool = get_redis_connection_pool() - print_verbose( - f"Number of available connections init: {self.async_redis_conn_pool.pool.qsize()}" - ) def init_async_client(self): from ._redis import get_redis_async_client - print_verbose( - f"Number of available connections client_init: {self.async_redis_conn_pool.pool.qsize()}" - ) return get_redis_async_client( connection_pool=self.async_redis_conn_pool, **self.redis_kwargs ) @@ -131,9 +125,6 @@ class RedisCache(BaseCache): except Exception as e: # NON blocking - notify users Redis is throwing an exception logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e) - print_verbose( - f"Number of available connections set_cache complete: {self.async_redis_conn_pool.pool.qsize()}" - ) async def async_set_cache_pipeline(self, cache_list, ttl=None): """ @@ -155,9 +146,6 @@ class RedisCache(BaseCache): pipe.set(cache_key, json.dumps(cache_value)) # Execute the pipeline and return the results. results = await pipe.execute() - print_verbose( - f"Number of available connections set_cache complete: {self.async_redis_conn_pool.pool.qsize()}" - ) print_verbose(f"pipeline results: {results}") # Optionally, you could process 'results' to make sure that all set operations were successful. @@ -212,10 +200,6 @@ class RedisCache(BaseCache): traceback.print_exc() logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e) - print_verbose( - f"Number of available connections get_cache complete: {self.async_redis_conn_pool.pool.qsize()}" - ) - def flush_cache(self): self.redis_client.flushall() From 3789b37e979c3624797e8ebdb3e71b82fe21a009 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 13 Jan 2024 14:21:24 +0530 Subject: [PATCH 12/22] test(conftest.py): create an event loop if one doesn't exist --- litellm/tests/conftest.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/litellm/tests/conftest.py b/litellm/tests/conftest.py index 6b0df0f9a..0e8b656ab 100644 --- a/litellm/tests/conftest.py +++ b/litellm/tests/conftest.py @@ -21,6 +21,10 @@ def setup_and_teardown(): import litellm importlib.reload(litellm) + import asyncio + + loop = asyncio.get_event_loop_policy().new_event_loop() + asyncio.set_event_loop(loop) print(litellm) # from litellm import Router, completion, aembedding, acompletion, embedding yield From ebf1bc842c418dfe94d92aa1f6ccc18a7abc4952 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 13 Jan 2024 14:23:04 +0530 Subject: [PATCH 13/22] fix(conftest.py): create an event loop if one isn't made --- litellm/tests/conftest.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/litellm/tests/conftest.py b/litellm/tests/conftest.py index 0e8b656ab..4cd277b31 100644 --- a/litellm/tests/conftest.py +++ b/litellm/tests/conftest.py @@ -29,6 +29,10 @@ def setup_and_teardown(): # from litellm import Router, completion, aembedding, acompletion, embedding yield + # Teardown code (executes after the yield point) + loop.close() # Close the loop created earlier + asyncio.set_event_loop(None) # Remove the reference to the loop + def pytest_collection_modifyitems(config, items): # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests From 7f83cca62cb757a685ec5b6bbc3ff58fd9c1fb1c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 13 Jan 2024 15:04:34 +0530 Subject: [PATCH 14/22] fix(caching.py): return updated kwargs from get_cache helper function --- litellm/caching.py | 6 +++--- litellm/tests/test_caching.py | 5 +---- litellm/utils.py | 6 +++++- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index 5601fab2f..59fc0ab67 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -678,7 +678,7 @@ class Cache: if k == "ttl": kwargs["ttl"] = v cached_data = {"timestamp": time.time(), "response": result} - return cache_key, cached_data + return cache_key, cached_data, kwargs else: raise Exception("cache key is None") except Exception as e: @@ -696,7 +696,7 @@ class Cache: None """ try: - cache_key, cached_data = self._add_cache_logic( + cache_key, cached_data, kwargs = self._add_cache_logic( result=result, *args, **kwargs ) self.cache.set_cache(cache_key, cached_data, **kwargs) @@ -710,7 +710,7 @@ class Cache: Async implementation of add_cache """ try: - cache_key, cached_data = self._add_cache_logic( + cache_key, cached_data, kwargs = self._add_cache_logic( result=result, *args, **kwargs ) await self.cache.async_set_cache(cache_key, cached_data, **kwargs) diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 3250a2621..695ad931a 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -109,10 +109,7 @@ def test_caching_with_cache_controls(): ) print(f"response1: {response1}") print(f"response2: {response2}") - assert ( - response2["choices"][0]["message"]["content"] - == response1["choices"][0]["message"]["content"] - ) + assert response2["id"] == response1["id"] except Exception as e: print(f"error occurred: {traceback.format_exc()}") pytest.fail(f"Error occurred: {e}") diff --git a/litellm/utils.py b/litellm/utils.py index 6059cdf85..3fee13937 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2029,7 +2029,11 @@ def client(original_function): # if caching is false or cache["no-cache"]==True, don't run this if ( ( - (kwargs.get("caching", None) is None and litellm.cache is not None) + ( + kwargs.get("caching", None) is None + and kwargs.get("cache", None) is None + and litellm.cache is not None + ) or kwargs.get("caching", False) == True or ( kwargs.get("cache", None) is not None From 40c952f7c2e36c2c0ed897edf89c12df6a9550aa Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 13 Jan 2024 15:33:57 +0530 Subject: [PATCH 15/22] fix(caching.py): fix async in-memory caching --- litellm/caching.py | 11 +++++++-- litellm/tests/test_caching.py | 44 ++++------------------------------- litellm/utils.py | 3 +++ 3 files changed, 16 insertions(+), 42 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index 59fc0ab67..594310b31 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -53,6 +53,13 @@ class InMemoryCache(BaseCache): async def async_set_cache(self, key, value, **kwargs): self.set_cache(key=key, value=value, **kwargs) + async def async_set_cache_pipeline(self, cache_list, ttl=None): + for cache_key, cache_value in cache_list: + if ttl is not None: + self.set_cache(key=cache_key, value=cache_value, ttl=ttl) + else: + self.set_cache(key=cache_key, value=cache_value) + def get_cache(self, key, **kwargs): if key in self.cache_dict: if key in self.ttl_dict: @@ -730,10 +737,10 @@ class Cache: preset_cache_key = litellm.cache.get_cache_key( *args, **{**kwargs, "input": i} ) + kwargs["cache_key"] = preset_cache_key embedding_response = result.data[idx] - cache_key, cached_data = self._add_cache_logic( + cache_key, cached_data, kwargs = self._add_cache_logic( result=embedding_response, - cache_key=preset_cache_key, *args, **kwargs, ) diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 695ad931a..89410598e 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -279,55 +279,20 @@ async def test_embedding_caching_azure_individual_items(): litellm.cache = Cache() common_msg = f"hey how's it going {uuid.uuid4()}" common_msg_2 = f"hey how's it going {uuid.uuid4()}" + embedding_1 = [common_msg] embedding_2 = [ common_msg, f"I'm fine {uuid.uuid4()}", - common_msg, - common_msg, - common_msg, - ] * 20 - embedding_2 = [ - common_msg, - f"I'm fine {uuid.uuid4()}", - common_msg, - common_msg, - common_msg, - ] * 20 - embedding_3 = [ - common_msg_2, - common_msg_2, - common_msg_2, - common_msg_2, - f"I'm fine {uuid.uuid4()}", - ] * 20 # make sure azure doesn't return cached 'i'm fine' responses + ] embedding_val_1 = await aembedding( model="azure/azure-embedding-model", input=embedding_1, caching=True ) - - second_response_start_time = time.time() embedding_val_2 = await aembedding( model="azure/azure-embedding-model", input=embedding_2, caching=True ) - if embedding_val_2 is not None: - second_response_end_time = time.time() - second_response_time = second_response_end_time - second_response_start_time - - third_response_start_time = time.time() - embedding_val_3 = await aembedding( - model="azure/azure-embedding-model", input=embedding_3, cache={"no-cache": True} - ) - if embedding_val_3 is not None: - third_response_end_time = time.time() - third_response_time = third_response_end_time - third_response_start_time - - print(f"second_response_time: {second_response_time}") - print(f"third_response_time: {third_response_time}") - - assert ( - second_response_time < third_response_time - 0.5 - ) # make sure it's actually faster - raise Exception(f"it works {second_response_time} < {third_response_time}") + print(f"embedding_val_2._hidden_params: {embedding_val_2._hidden_params}") + assert embedding_val_2._hidden_params["cache_hit"] == True @pytest.mark.asyncio @@ -369,7 +334,6 @@ async def test_redis_cache_basic(): ) print(f"stored_val: {stored_val}") assert stored_val["id"] == response1.id - raise Exception("it worked!") def test_redis_cache_completion(): diff --git a/litellm/utils.py b/litellm/utils.py index 3fee13937..344917118 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2357,6 +2357,9 @@ def client(original_function): model=kwargs.get("model"), data=[None] * len(original_kwargs_input), ) + final_embedding_cached_response._hidden_params[ + "cache_hit" + ] = True for val in non_null_list: idx, cr = val # (idx, cr) tuple From fb53d18d6a0221086fc254d57d9207bacedd47c4 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 13 Jan 2024 15:55:56 +0530 Subject: [PATCH 16/22] refactor(main.py): trigger rebuild --- litellm/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/main.py b/litellm/main.py index 413c8b6e9..e696c3c6a 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -10,6 +10,7 @@ import os, openai, sys, json, inspect, uuid, datetime, threading from typing import Any, Literal, Union from functools import partial + import dotenv, traceback, random, asyncio, time, contextvars from copy import deepcopy import httpx From 3c02ad8b9596e473e9972109a7add437fc8bbfe0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 13 Jan 2024 16:19:30 +0530 Subject: [PATCH 17/22] fix(utils.py): exclude s3 caching from individual item caching for embedding list can't bulk upload to s3, so this will slow down calls https://github.com/BerriAI/litellm/pull/1417 --- litellm/caching.py | 4 ++-- litellm/utils.py | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index 594310b31..c3fbaad6d 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -444,9 +444,9 @@ class Cache: """ if type == "redis": self.cache: BaseCache = RedisCache(host, port, password, **kwargs) - if type == "local": + elif type == "local": self.cache = InMemoryCache() - if type == "s3": + elif type == "s3": self.cache = S3Cache( s3_bucket_name=s3_bucket_name, s3_region_name=s3_region_name, diff --git a/litellm/utils.py b/litellm/utils.py index 344917118..15494c3ef 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -53,6 +53,7 @@ from .integrations.litedebugger import LiteDebugger from .proxy._types import KeyManagementSystem from openai import OpenAIError as OriginalError from openai._models import BaseModel as OpenAIObject +from .caching import S3Cache from .exceptions import ( AuthenticationError, BadRequestError, @@ -2338,6 +2339,10 @@ def client(original_function): call_type == CallTypes.aembedding.value and cached_result is not None and isinstance(cached_result, list) + and litellm.cache is not None + and not isinstance( + litellm.cache.cache, S3Cache + ) # s3 doesn't support bulk writing. Exclude. ): remaining_list = [] non_null_list = [] @@ -2458,8 +2463,13 @@ def client(original_function): if isinstance(result, litellm.ModelResponse) or isinstance( result, litellm.EmbeddingResponse ): - if isinstance(result, EmbeddingResponse) and isinstance( - kwargs["input"], list + if ( + isinstance(result, EmbeddingResponse) + and isinstance(kwargs["input"], list) + and litellm.cache is not None + and not isinstance( + litellm.cache.cache, S3Cache + ) # s3 doesn't support bulk writing. Exclude. ): asyncio.create_task( litellm.cache.async_add_cache_pipeline( From c2f674ebe0ec39e31887a16be80e38b5861fb409 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 3 Feb 2024 18:58:58 -0800 Subject: [PATCH 18/22] fix(utils.py): fix conditional check --- litellm/utils.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/litellm/utils.py b/litellm/utils.py index a10208564..ae8425d09 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2512,12 +2512,19 @@ def client(original_function): ) ) cached_result = await asyncio.gather(*tasks) + ## check if cached result is None ## + if cached_result is not None and isinstance( + cached_result, list + ): + if len(cached_result) == 1 and cached_result[0] is None: + cached_result = None else: preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) kwargs[ "preset_cache_key" ] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key cached_result = litellm.cache.get_cache(*args, **kwargs) + if cached_result is not None and not isinstance( cached_result, list ): @@ -2611,7 +2618,6 @@ def client(original_function): non_null_list.append((idx, cr)) original_kwargs_input = kwargs["input"] kwargs["input"] = remaining_list - if len(non_null_list) > 0: print_verbose( f"EMBEDDING CACHE HIT! - {len(non_null_list)}" @@ -2628,7 +2634,6 @@ def client(original_function): idx, cr = val # (idx, cr) tuple if cr is not None: final_embedding_cached_response.data[idx] = cr - if len(remaining_list) == 0: # LOG SUCCESS cache_hit = True @@ -2769,7 +2774,8 @@ def client(original_function): result._response_ms = ( end_time - start_time ).total_seconds() * 1000 # return response latency in ms like openai - elif ( + + if ( isinstance(result, EmbeddingResponse) and final_embedding_cached_response is not None ): @@ -2783,6 +2789,10 @@ def client(original_function): final_data_list.append(item) final_embedding_cached_response.data = final_data_list + final_embedding_cached_response._hidden_params["cache_hit"] = True + final_embedding_cached_response._response_ms = ( + end_time - start_time + ).total_seconds() * 1000 return final_embedding_cached_response return result From c49c88c8e552064c89611fc4f14cf2d4ec42fccd Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 3 Feb 2024 19:22:48 -0800 Subject: [PATCH 19/22] fix(utils.py): route together ai calls to openai client together ai is now openai-compatible n --- litellm/__init__.py | 2 ++ litellm/llms/openai.py | 4 ++-- litellm/llms/together_ai.py | 4 ++++ litellm/main.py | 4 ++++ litellm/tests/test_completion.py | 3 ++- litellm/utils.py | 10 +++++++++- 6 files changed, 23 insertions(+), 4 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index 6bdfe5e10..3f2a1e4b4 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -285,6 +285,7 @@ openai_compatible_endpoints: List = [ "api.endpoints.anyscale.com/v1", "api.deepinfra.com/v1/openai", "api.mistral.ai/v1", + "api.together.xyz/v1", ] # this is maintained for Exception Mapping @@ -294,6 +295,7 @@ openai_compatible_providers: List = [ "deepinfra", "perplexity", "xinference", + "together_ai", ] diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 7121d7bc7..3f151d1a9 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -440,8 +440,8 @@ class OpenAIChatCompletion(BaseLLM): input=data["messages"], api_key=api_key, additional_args={ - "headers": headers, - "api_base": api_base, + "headers": {"Authorization": f"Bearer {openai_client.api_key}"}, + "api_base": openai_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data, }, diff --git a/litellm/llms/together_ai.py b/litellm/llms/together_ai.py index d4b85e9ca..15ed29916 100644 --- a/litellm/llms/together_ai.py +++ b/litellm/llms/together_ai.py @@ -1,3 +1,7 @@ +""" +Deprecated. We now do together ai calls via the openai client. +Reference: https://docs.together.ai/docs/openai-api-compatibility +""" import os, types import json from enum import Enum diff --git a/litellm/main.py b/litellm/main.py index 2df5de89c..bc33a69e5 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -791,6 +791,7 @@ def completion( or custom_llm_provider == "anyscale" or custom_llm_provider == "mistral" or custom_llm_provider == "openai" + or custom_llm_provider == "together_ai" or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo ): # allow user to make an openai call with a custom base # note: if a user sets a custom base - we should ensure this works @@ -1330,6 +1331,9 @@ def completion( or ("togethercomputer" in model) or (model in litellm.together_ai_models) ): + """ + Deprecated. We now do together ai calls via the openai client - https://docs.together.ai/docs/openai-api-compatibility + """ custom_llm_provider = "together_ai" together_ai_key = ( api_key diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 54640b54b..d98745d0b 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -1994,11 +1994,12 @@ def test_completion_palm_stream(): def test_completion_together_ai_stream(): + litellm.set_verbose = True user_message = "Write 1pg about YC & litellm" messages = [{"content": user_message, "role": "user"}] try: response = completion( - model="together_ai/mistralai/Mistral-7B-Instruct-v0.1", + model="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1", messages=messages, stream=True, max_tokens=5, diff --git a/litellm/utils.py b/litellm/utils.py index ae8425d09..6aba17f95 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -863,6 +863,7 @@ class Logging: curl_command += additional_args.get("request_str", None) elif api_base == "": curl_command = self.model_call_details + print_verbose(f"\033[92m{curl_command}\033[0m\n") verbose_logger.info(f"\033[92m{curl_command}\033[0m\n") if self.logger_fn and callable(self.logger_fn): try: @@ -4043,7 +4044,7 @@ def get_optional_params( _check_valid_arg(supported_params=supported_params) if stream: - optional_params["stream_tokens"] = stream + optional_params["stream"] = stream if temperature is not None: optional_params["temperature"] = temperature if top_p is not None: @@ -4677,6 +4678,13 @@ def get_llm_provider( # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1 api_base = "https://api.voyageai.com/v1" dynamic_api_key = get_secret("VOYAGE_API_KEY") + elif custom_llm_provider == "together_ai": + api_base = "https://api.together.xyz/v1" + dynamic_api_key = ( + get_secret("TOGETHER_API_KEY") + or get_secret("TOGETHER_AI_API_KEY") + or get_secret("TOGETHERAI_API_KEY") + ) return model, custom_llm_provider, dynamic_api_key, api_base elif model.split("/", 1)[0] in litellm.provider_list: custom_llm_provider = model.split("/", 1)[0] From efb6123d28cd46ba16cabc82ab18405038899cfe Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 3 Feb 2024 19:35:09 -0800 Subject: [PATCH 20/22] fix(utils.py): support get_secret("TOGETHER_AI_TOKEN") --- litellm/tests/test_completion.py | 4 ++-- litellm/utils.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index d98745d0b..2ef525e9f 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -37,7 +37,7 @@ def test_completion_custom_provider_model_name(): try: litellm.cache = None response = completion( - model="together_ai/mistralai/Mistral-7B-Instruct-v0.1", + model="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1", messages=messages, logger_fn=logger_fn, ) @@ -1369,7 +1369,7 @@ def test_customprompt_together_ai(): print(litellm.success_callback) print(litellm._async_success_callback) response = completion( - model="together_ai/mistralai/Mistral-7B-Instruct-v0.1", + model="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1", messages=messages, roles={ "system": { diff --git a/litellm/utils.py b/litellm/utils.py index 6aba17f95..d67ee37e3 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4684,6 +4684,7 @@ def get_llm_provider( get_secret("TOGETHER_API_KEY") or get_secret("TOGETHER_AI_API_KEY") or get_secret("TOGETHERAI_API_KEY") + or get_secret("TOGETHER_AI_TOKEN") ) return model, custom_llm_provider, dynamic_api_key, api_base elif model.split("/", 1)[0] in litellm.provider_list: From 25a0e1572731655ad23f1a512ffd6ab2135cb19d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 3 Feb 2024 19:59:32 -0800 Subject: [PATCH 21/22] fix(utils.py): support time based pricing for openai-compatible together ai --- litellm/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/litellm/utils.py b/litellm/utils.py index d67ee37e3..bcc5dc7df 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -3455,7 +3455,11 @@ def completion_cost( else: raise Exception(f"Model={model} not found in completion cost model map") # Calculate cost based on prompt_tokens, completion_tokens - if "togethercomputer" in model or "together_ai" in model: + if ( + "togethercomputer" in model + or "together_ai" in model + or custom_llm_provider == "together_ai" + ): # together ai prices based on size of llm # get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json model = get_model_params_and_category(model) From b47b2837ebba020f44e3367c1f343ee2f08e3785 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 3 Feb 2024 20:34:05 -0800 Subject: [PATCH 22/22] test(test_parallel_request_limiter.py): fix test --- litellm/tests/test_parallel_request_limiter.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py index 1155e5794..27d81356f 100644 --- a/litellm/tests/test_parallel_request_limiter.py +++ b/litellm/tests/test_parallel_request_limiter.py @@ -525,17 +525,12 @@ async def test_streaming_router_tpm_limit(): continue await asyncio.sleep(5) # success is done in a separate thread - try: - await parallel_request_handler.async_pre_call_hook( - user_api_key_dict=user_api_key_dict, - cache=local_cache, - data={}, - call_type="", - ) - - pytest.fail(f"Expected call to fail") - except Exception as e: - assert e.status_code == 429 + assert ( + parallel_request_handler.user_api_key_cache.get_cache( + key=request_count_api_key + )["current_tpm"] + > 0 + ) @pytest.mark.asyncio