diff --git a/litellm/__init__.py b/litellm/__init__.py index 6bdfe5e10..3f2a1e4b4 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -285,6 +285,7 @@ openai_compatible_endpoints: List = [ "api.endpoints.anyscale.com/v1", "api.deepinfra.com/v1/openai", "api.mistral.ai/v1", + "api.together.xyz/v1", ] # this is maintained for Exception Mapping @@ -294,6 +295,7 @@ openai_compatible_providers: List = [ "deepinfra", "perplexity", "xinference", + "together_ai", ] diff --git a/litellm/_redis.py b/litellm/_redis.py index bee73f134..4484926d4 100644 --- a/litellm/_redis.py +++ b/litellm/_redis.py @@ -11,6 +11,7 @@ import os import inspect import redis, litellm +import redis.asyncio as async_redis from typing import List, Optional @@ -67,7 +68,10 @@ def get_redis_url_from_environment(): ) -def get_redis_client(**env_overrides): +def _get_redis_client_logic(**env_overrides): + """ + Common functionality across sync + async redis client implementations + """ ### check if "os.environ/" passed in for k, v in env_overrides.items(): if isinstance(v, str) and v.startswith("os.environ/"): @@ -85,9 +89,33 @@ def get_redis_client(**env_overrides): redis_kwargs.pop("port", None) redis_kwargs.pop("db", None) redis_kwargs.pop("password", None) - - return redis.Redis.from_url(**redis_kwargs) elif "host" not in redis_kwargs or redis_kwargs["host"] is None: raise ValueError("Either 'host' or 'url' must be specified for redis.") litellm.print_verbose(f"redis_kwargs: {redis_kwargs}") + return redis_kwargs + + +def get_redis_client(**env_overrides): + redis_kwargs = _get_redis_client_logic(**env_overrides) + if "url" in redis_kwargs and redis_kwargs["url"] is not None: + return redis.Redis.from_url(**redis_kwargs) return redis.Redis(**redis_kwargs) + + +def get_redis_async_client(**env_overrides): + redis_kwargs = _get_redis_client_logic(**env_overrides) + if "url" in redis_kwargs and redis_kwargs["url"] is not None: + return async_redis.Redis.from_url(**redis_kwargs) + return async_redis.Redis( + socket_timeout=5, + **redis_kwargs, + ) + + +def get_redis_connection_pool(**env_overrides): + redis_kwargs = _get_redis_client_logic(**env_overrides) + if "url" in redis_kwargs and redis_kwargs["url"] is not None: + return async_redis.BlockingConnectionPool.from_url( + timeout=5, url=redis_kwargs["url"] + ) + return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs) diff --git a/litellm/caching.py b/litellm/caching.py index 257bb1ca5..d0721fe9a 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -8,7 +8,7 @@ # Thank you users! We ❤️ you! 
- Krrish & Ishaan import litellm -import time, logging +import time, logging, asyncio import json, traceback, ast, hashlib from typing import Optional, Literal, List, Union, Any from openai._models import BaseModel as OpenAIObject @@ -28,9 +28,18 @@ class BaseCache: def set_cache(self, key, value, **kwargs): raise NotImplementedError + async def async_set_cache(self, key, value, **kwargs): + raise NotImplementedError + def get_cache(self, key, **kwargs): raise NotImplementedError + async def async_get_cache(self, key, **kwargs): + raise NotImplementedError + + async def disconnect(self): + raise NotImplementedError + class InMemoryCache(BaseCache): def __init__(self): @@ -43,6 +52,16 @@ class InMemoryCache(BaseCache): if "ttl" in kwargs: self.ttl_dict[key] = time.time() + kwargs["ttl"] + async def async_set_cache(self, key, value, **kwargs): + self.set_cache(key=key, value=value, **kwargs) + + async def async_set_cache_pipeline(self, cache_list, ttl=None): + for cache_key, cache_value in cache_list: + if ttl is not None: + self.set_cache(key=cache_key, value=cache_value, ttl=ttl) + else: + self.set_cache(key=cache_key, value=cache_value) + def get_cache(self, key, **kwargs): if key in self.cache_dict: if key in self.ttl_dict: @@ -57,21 +76,27 @@ class InMemoryCache(BaseCache): return cached_response return None + async def async_get_cache(self, key, **kwargs): + return self.get_cache(key=key, **kwargs) + def flush_cache(self): self.cache_dict.clear() self.ttl_dict.clear() + + async def disconnect(self): + pass + def delete_cache(self, key): self.cache_dict.pop(key, None) self.ttl_dict.pop(key, None) class RedisCache(BaseCache): - def __init__(self, host=None, port=None, password=None, **kwargs): - import redis + # if users don't provider one, use the default litellm cache - # if users don't provider one, use the default litellm cache - from ._redis import get_redis_client + def __init__(self, host=None, port=None, password=None, **kwargs): + from ._redis import get_redis_client, get_redis_connection_pool redis_kwargs = {} if host is not None: @@ -82,18 +107,84 @@ class RedisCache(BaseCache): redis_kwargs["password"] = password redis_kwargs.update(kwargs) - self.redis_client = get_redis_client(**redis_kwargs) + self.redis_kwargs = redis_kwargs + self.async_redis_conn_pool = get_redis_connection_pool() + + def init_async_client(self): + from ._redis import get_redis_async_client + + return get_redis_async_client( + connection_pool=self.async_redis_conn_pool, **self.redis_kwargs + ) def set_cache(self, key, value, **kwargs): ttl = kwargs.get("ttl", None) - print_verbose(f"Set Redis Cache: key: {key}\nValue {value}") + print_verbose(f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}") try: self.redis_client.set(name=key, value=str(value), ex=ttl) except Exception as e: # NON blocking - notify users Redis is throwing an exception logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e) + async def async_set_cache(self, key, value, **kwargs): + _redis_client = self.init_async_client() + async with _redis_client as redis_client: + ttl = kwargs.get("ttl", None) + print_verbose( + f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}" + ) + try: + await redis_client.set(name=key, value=json.dumps(value), ex=ttl) + except Exception as e: + # NON blocking - notify users Redis is throwing an exception + logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e) + + async def async_set_cache_pipeline(self, cache_list, ttl=None): + """ + Use Redis Pipelines for 
bulk write operations + """ + _redis_client = self.init_async_client() + try: + async with _redis_client as redis_client: + async with redis_client.pipeline(transaction=True) as pipe: + # Iterate through each key-value pair in the cache_list and set them in the pipeline. + for cache_key, cache_value in cache_list: + print_verbose( + f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}" + ) + # Set the value with a TTL if it's provided. + if ttl is not None: + pipe.setex(cache_key, ttl, json.dumps(cache_value)) + else: + pipe.set(cache_key, json.dumps(cache_value)) + # Execute the pipeline and return the results. + results = await pipe.execute() + + print_verbose(f"pipeline results: {results}") + # Optionally, you could process 'results' to make sure that all set operations were successful. + return results + except Exception as e: + print_verbose(f"Error occurred in pipeline write - {str(e)}") + # NON blocking - notify users Redis is throwing an exception + logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e) + + def _get_cache_logic(self, cached_response: Any): + """ + Common 'get_cache_logic' across sync + async redis client implementations + """ + if cached_response is None: + return cached_response + # cached_response is in `b{} convert it to ModelResponse + cached_response = cached_response.decode("utf-8") # Convert bytes to string + try: + cached_response = json.loads( + cached_response + ) # Convert string to dictionary + except: + cached_response = ast.literal_eval(cached_response) + return cached_response + def get_cache(self, key, **kwargs): try: print_verbose(f"Get Redis Cache: key: {key}") @@ -101,30 +192,40 @@ class RedisCache(BaseCache): print_verbose( f"Got Redis Cache: key: {key}, cached_response {cached_response}" ) - if cached_response != None: - # cached_response is in `b{} convert it to ModelResponse - cached_response = cached_response.decode( - "utf-8" - ) # Convert bytes to string - try: - cached_response = json.loads( - cached_response - ) # Convert string to dictionary - except: - cached_response = ast.literal_eval(cached_response) - return cached_response + return self._get_cache_logic(cached_response=cached_response) except Exception as e: # NON blocking - notify users Redis is throwing an exception traceback.print_exc() logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e) + async def async_get_cache(self, key, **kwargs): + _redis_client = self.init_async_client() + async with _redis_client as redis_client: + try: + print_verbose(f"Get Redis Cache: key: {key}") + cached_response = await redis_client.get(key) + print_verbose( + f"Got Async Redis Cache: key: {key}, cached_response {cached_response}" + ) + response = self._get_cache_logic(cached_response=cached_response) + return response + except Exception as e: + # NON blocking - notify users Redis is throwing an exception + traceback.print_exc() + logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e) + def flush_cache(self): self.redis_client.flushall() + + async def disconnect(self): + pass + def delete_cache(self, key): self.redis_client.delete(key) + class S3Cache(BaseCache): def __init__( self, @@ -202,6 +303,9 @@ class S3Cache(BaseCache): # NON blocking - notify users S3 is throwing an exception print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}") + async def async_set_cache(self, key, value, **kwargs): + self.set_cache(key=key, value=value, **kwargs) + def get_cache(self, key, **kwargs): import boto3, 
botocore @@ -244,6 +348,9 @@ class S3Cache(BaseCache): traceback.print_exc() print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}") + async def async_get_cache(self, key, **kwargs): + return self.get_cache(key=key, **kwargs) + def flush_cache(self): pass @@ -361,9 +468,9 @@ class Cache: """ if type == "redis": self.cache: BaseCache = RedisCache(host, port, password, **kwargs) - if type == "local": + elif type == "local": self.cache = InMemoryCache() - if type == "s3": + elif type == "s3": self.cache = S3Cache( s3_bucket_name=s3_bucket_name, s3_region_name=s3_region_name, @@ -489,6 +596,45 @@ class Cache: } time.sleep(0.02) + def _get_cache_logic( + self, + cached_result: Optional[Any], + max_age: Optional[float], + ): + """ + Common get cache logic across sync + async implementations + """ + # Check if a timestamp was stored with the cached response + if ( + cached_result is not None + and isinstance(cached_result, dict) + and "timestamp" in cached_result + ): + timestamp = cached_result["timestamp"] + current_time = time.time() + + # Calculate age of the cached response + response_age = current_time - timestamp + + # Check if the cached response is older than the max-age + if max_age is not None and response_age > max_age: + return None # Cached response is too old + + # If the response is fresh, or there's no max-age requirement, return the cached response + # cached_response is in `b{} convert it to ModelResponse + cached_response = cached_result.get("response") + try: + if isinstance(cached_response, dict): + pass + else: + cached_response = json.loads( + cached_response # type: ignore + ) # Convert string to dictionary + except: + cached_response = ast.literal_eval(cached_response) # type: ignore + return cached_response + return cached_result + def get_cache(self, *args, **kwargs): """ Retrieves the cached result for the given arguments. @@ -511,54 +657,40 @@ class Cache: "s-max-age", cache_control_args.get("s-maxage", float("inf")) ) cached_result = self.cache.get_cache(cache_key) - # Check if a timestamp was stored with the cached response - if ( - cached_result is not None - and isinstance(cached_result, dict) - and "timestamp" in cached_result - and max_age is not None - ): - timestamp = cached_result["timestamp"] - current_time = time.time() - - # Calculate age of the cached response - response_age = current_time - timestamp - - # Check if the cached response is older than the max-age - if response_age > max_age: - print_verbose( - f"Cached response for key {cache_key} is too old. Max-age: {max_age}s, Age: {response_age}s" - ) - return None # Cached response is too old - - # If the response is fresh, or there's no max-age requirement, return the cached response - # cached_response is in `b{} convert it to ModelResponse - cached_response = cached_result.get("response") - try: - if isinstance(cached_response, dict): - pass - else: - cached_response = json.loads( - cached_response - ) # Convert string to dictionary - except: - cached_response = ast.literal_eval(cached_response) - return cached_response - return cached_result + return self._get_cache_logic( + cached_result=cached_result, max_age=max_age + ) except Exception as e: print_verbose(f"An exception occurred: {traceback.format_exc()}") return None - def add_cache(self, result, *args, **kwargs): + async def async_get_cache(self, *args, **kwargs): """ - Adds a result to the cache. + Async get cache implementation. 
- Args: - *args: args to litellm.completion() or embedding() - **kwargs: kwargs to litellm.completion() or embedding() + Used for embedding calls in async wrapper + """ + try: # never block execution + if "cache_key" in kwargs: + cache_key = kwargs["cache_key"] + else: + cache_key = self.get_cache_key(*args, **kwargs) + if cache_key is not None: + cache_control_args = kwargs.get("cache", {}) + max_age = cache_control_args.get( + "s-max-age", cache_control_args.get("s-maxage", float("inf")) + ) + cached_result = await self.cache.async_get_cache(cache_key) + return self._get_cache_logic( + cached_result=cached_result, max_age=max_age + ) + except Exception as e: + print_verbose(f"An exception occurred: {traceback.format_exc()}") + return None - Returns: - None + def _add_cache_logic(self, result, *args, **kwargs): + """ + Common implementation across sync + async add_cache functions """ try: if "cache_key" in kwargs: @@ -577,14 +709,82 @@ class Cache: if k == "ttl": kwargs["ttl"] = v cached_data = {"timestamp": time.time(), "response": result} - self.cache.set_cache(cache_key, cached_data, **kwargs) + return cache_key, cached_data, kwargs + else: + raise Exception("cache key is None") + except Exception as e: + raise e + + def add_cache(self, result, *args, **kwargs): + """ + Adds a result to the cache. + + Args: + *args: args to litellm.completion() or embedding() + **kwargs: kwargs to litellm.completion() or embedding() + + Returns: + None + """ + try: + cache_key, cached_data, kwargs = self._add_cache_logic( + result=result, *args, **kwargs + ) + self.cache.set_cache(cache_key, cached_data, **kwargs) except Exception as e: print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}") traceback.print_exc() pass - async def _async_add_cache(self, result, *args, **kwargs): - self.add_cache(result, *args, **kwargs) + async def async_add_cache(self, result, *args, **kwargs): + """ + Async implementation of add_cache + """ + try: + cache_key, cached_data, kwargs = self._add_cache_logic( + result=result, *args, **kwargs + ) + await self.cache.async_set_cache(cache_key, cached_data, **kwargs) + except Exception as e: + print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}") + traceback.print_exc() + + async def async_add_cache_pipeline(self, result, *args, **kwargs): + """ + Async implementation of add_cache for Embedding calls + + Does a bulk write, to prevent using too many clients + """ + try: + cache_list = [] + for idx, i in enumerate(kwargs["input"]): + preset_cache_key = litellm.cache.get_cache_key( + *args, **{**kwargs, "input": i} + ) + kwargs["cache_key"] = preset_cache_key + embedding_response = result.data[idx] + cache_key, cached_data, kwargs = self._add_cache_logic( + result=embedding_response, + *args, + **kwargs, + ) + cache_list.append((cache_key, cached_data)) + if hasattr(self.cache, "async_set_cache_pipeline"): + await self.cache.async_set_cache_pipeline(cache_list=cache_list) + else: + tasks = [] + for val in cache_list: + tasks.append( + self.cache.async_set_cache(cache_key, cached_data, **kwargs) + ) + await asyncio.gather(*tasks) + except Exception as e: + print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}") + traceback.print_exc() + + async def disconnect(self): + if hasattr(self.cache, "disconnect"): + await self.cache.disconnect() def enable_cache( diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 7121d7bc7..3f151d1a9 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -440,8 +440,8 @@ class OpenAIChatCompletion(BaseLLM): 
input=data["messages"], api_key=api_key, additional_args={ - "headers": headers, - "api_base": api_base, + "headers": {"Authorization": f"Bearer {openai_client.api_key}"}, + "api_base": openai_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data, }, diff --git a/litellm/llms/together_ai.py b/litellm/llms/together_ai.py index d4b85e9ca..15ed29916 100644 --- a/litellm/llms/together_ai.py +++ b/litellm/llms/together_ai.py @@ -1,3 +1,7 @@ +""" +Deprecated. We now do together ai calls via the openai client. +Reference: https://docs.together.ai/docs/openai-api-compatibility +""" import os, types import json from enum import Enum diff --git a/litellm/main.py b/litellm/main.py index 401f60d42..bc33a69e5 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -10,6 +10,7 @@ import os, openai, sys, json, inspect, uuid, datetime, threading from typing import Any, Literal, Union from functools import partial + import dotenv, traceback, random, asyncio, time, contextvars from copy import deepcopy import httpx @@ -234,6 +235,9 @@ async def acompletion( "model_list": model_list, "acompletion": True, # assuming this is a required parameter } + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=completion_kwargs.get("base_url", None) + ) try: # Use a partial function to pass your keyword arguments func = partial(completion, **completion_kwargs, **kwargs) @@ -245,7 +249,6 @@ async def acompletion( _, custom_llm_provider, _, _ = get_llm_provider( model=model, api_base=kwargs.get("api_base", None) ) - if ( custom_llm_provider == "openai" or custom_llm_provider == "azure" @@ -788,6 +791,7 @@ def completion( or custom_llm_provider == "anyscale" or custom_llm_provider == "mistral" or custom_llm_provider == "openai" + or custom_llm_provider == "together_ai" or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo ): # allow user to make an openai call with a custom base # note: if a user sets a custom base - we should ensure this works @@ -1327,6 +1331,9 @@ def completion( or ("togethercomputer" in model) or (model in litellm.together_ai_models) ): + """ + Deprecated. We now do together ai calls via the openai client - https://docs.together.ai/docs/openai-api-compatibility + """ custom_llm_provider = "together_ai" together_ai_key = ( api_key diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 82bdbd625..890cf5294 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -380,7 +380,7 @@ def run_server( import gunicorn.app.base except: raise ImportError( - "Uvicorn, gunicorn needs to be imported. Run - `pip 'litellm[proxy]'`" + "uvicorn, gunicorn needs to be imported. 
Run - `pip install 'litellm[proxy]'`" ) if config is not None: @@ -444,6 +444,7 @@ def run_server( ) if port == 8000 and is_port_in_use(port): port = random.randint(1024, 49152) + from litellm.proxy.proxy_server import app if run_gunicorn == False: @@ -521,5 +522,6 @@ def run_server( ).run() # Run gunicorn + if __name__ == "__main__": run_server() diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index aa2242211..0501ec746 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -7,6 +7,20 @@ import secrets, subprocess import hashlib, uuid import warnings import importlib +import warnings + + +def showwarning(message, category, filename, lineno, file=None, line=None): + traceback_info = f"{filename}:{lineno}: {category.__name__}: {message}\n" + if file is not None: + file.write(traceback_info) + + +warnings.showwarning = showwarning +warnings.filterwarnings("default", category=UserWarning) + +# Your client code here + messages: list = [] sys.path.insert( @@ -4053,9 +4067,12 @@ def _has_user_setup_sso(): async def shutdown_event(): global prisma_client, master_key, user_custom_auth, user_custom_key_generate if prisma_client: + verbose_proxy_logger.debug("Disconnecting from Prisma") await prisma_client.disconnect() + if litellm.cache is not None: + await litellm.cache.disconnect() ## RESET CUSTOM VARIABLES ## cleanup_router_config_variables() diff --git a/litellm/tests/conftest.py b/litellm/tests/conftest.py index 6b0df0f9a..4cd277b31 100644 --- a/litellm/tests/conftest.py +++ b/litellm/tests/conftest.py @@ -21,10 +21,18 @@ def setup_and_teardown(): import litellm importlib.reload(litellm) + import asyncio + + loop = asyncio.get_event_loop_policy().new_event_loop() + asyncio.set_event_loop(loop) print(litellm) # from litellm import Router, completion, aembedding, acompletion, embedding yield + # Teardown code (executes after the yield point) + loop.close() # Close the loop created earlier + asyncio.set_event_loop(None) # Remove the reference to the loop + def pytest_collection_modifyitems(config, items): # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index efe7a5443..468ab6f80 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -11,10 +11,10 @@ sys.path.insert( ) # Adds the parent directory to the system path import pytest import litellm -from litellm import embedding, completion +from litellm import embedding, completion, aembedding from litellm.caching import Cache import random -import hashlib +import hashlib, asyncio # litellm.set_verbose=True @@ -106,10 +106,7 @@ def test_caching_with_cache_controls(): ) print(f"response1: {response1}") print(f"response2: {response2}") - assert ( - response2["choices"][0]["message"]["content"] - == response1["choices"][0]["message"]["content"] - ) + assert response2["id"] == response1["id"] except Exception as e: print(f"error occurred: {traceback.format_exc()}") pytest.fail(f"Error occurred: {e}") @@ -259,6 +256,84 @@ def test_embedding_caching_azure(): # test_embedding_caching_azure() +@pytest.mark.asyncio +async def test_embedding_caching_azure_individual_items(): + """ + Tests caching for individual items in an embedding list + + - Cache an item + - call aembedding(..) with the item + 1 unique item + - compare to a 2nd aembedding (...) with 2 unique items + ``` + embedding_1 = ["hey how's it going", "I'm doing well"] + embedding_val_1 = embedding(...) 
+ + embedding_2 = ["hey how's it going", "I'm fine"] + embedding_val_2 = embedding(...) + + assert embedding_val_1[0]["id"] == embedding_val_2[0]["id"] + ``` + """ + litellm.cache = Cache() + common_msg = f"hey how's it going {uuid.uuid4()}" + common_msg_2 = f"hey how's it going {uuid.uuid4()}" + embedding_1 = [common_msg] + embedding_2 = [ + common_msg, + f"I'm fine {uuid.uuid4()}", + ] + + embedding_val_1 = await aembedding( + model="azure/azure-embedding-model", input=embedding_1, caching=True + ) + embedding_val_2 = await aembedding( + model="azure/azure-embedding-model", input=embedding_2, caching=True + ) + print(f"embedding_val_2._hidden_params: {embedding_val_2._hidden_params}") + assert embedding_val_2._hidden_params["cache_hit"] == True + + +@pytest.mark.asyncio +async def test_redis_cache_basic(): + """ + Init redis client + - write to client + - read from client + """ + litellm.set_verbose = False + + random_number = random.randint( + 1, 100000 + ) # add a random number to ensure it's always adding / reading from cache + messages = [ + {"role": "user", "content": f"write a one sentence poem about: {random_number}"} + ] + litellm.cache = Cache( + type="redis", + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) + response1 = completion( + model="gpt-3.5-turbo", + messages=messages, + ) + + cache_key = litellm.cache.get_cache_key( + model="gpt-3.5-turbo", + messages=messages, + ) + print(f"cache_key: {cache_key}") + litellm.cache.add_cache(result=response1, cache_key=cache_key) + print(f"cache key pre async get: {cache_key}") + stored_val = await litellm.cache.async_get_cache( + model="gpt-3.5-turbo", + messages=messages, + ) + print(f"stored_val: {stored_val}") + assert stored_val["id"] == response1.id + + def test_redis_cache_completion(): litellm.set_verbose = False @@ -406,7 +481,7 @@ def test_redis_cache_acompletion_stream(): import asyncio try: - litellm.set_verbose = True + litellm.set_verbose = False random_word = generate_random_word() messages = [ { @@ -434,7 +509,6 @@ def test_redis_cache_acompletion_stream(): stream=True, ) async for chunk in response1: - print(chunk) response_1_content += chunk.choices[0].delta.content or "" print(response_1_content) @@ -452,7 +526,6 @@ def test_redis_cache_acompletion_stream(): stream=True, ) async for chunk in response2: - print(chunk) response_2_content += chunk.choices[0].delta.content or "" print(response_2_content) @@ -914,101 +987,3 @@ def test_cache_context_managers(): # test_cache_context_managers() - -# test_custom_redis_cache_params() - -# def test_redis_cache_with_ttl(): -# cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) -# sample_model_response_object_str = """{ -# "choices": [ -# { -# "finish_reason": "stop", -# "index": 0, -# "message": { -# "role": "assistant", -# "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic." -# } -# } -# ], -# "created": 1691429984.3852863, -# "model": "claude-instant-1", -# "usage": { -# "prompt_tokens": 18, -# "completion_tokens": 23, -# "total_tokens": 41 -# } -# }""" -# sample_model_response_object = { -# "choices": [ -# { -# "finish_reason": "stop", -# "index": 0, -# "message": { -# "role": "assistant", -# "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic." 
-# } -# } -# ], -# "created": 1691429984.3852863, -# "model": "claude-instant-1", -# "usage": { -# "prompt_tokens": 18, -# "completion_tokens": 23, -# "total_tokens": 41 -# } -# } -# cache.add_cache(cache_key="test_key", result=sample_model_response_object_str, ttl=1) -# cached_value = cache.get_cache(cache_key="test_key") -# print(f"cached-value: {cached_value}") -# assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content'] -# time.sleep(2) -# assert cache.get_cache(cache_key="test_key") is None - -# # test_redis_cache_with_ttl() - -# def test_in_memory_cache_with_ttl(): -# cache = Cache(type="local") -# sample_model_response_object_str = """{ -# "choices": [ -# { -# "finish_reason": "stop", -# "index": 0, -# "message": { -# "role": "assistant", -# "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic." -# } -# } -# ], -# "created": 1691429984.3852863, -# "model": "claude-instant-1", -# "usage": { -# "prompt_tokens": 18, -# "completion_tokens": 23, -# "total_tokens": 41 -# } -# }""" -# sample_model_response_object = { -# "choices": [ -# { -# "finish_reason": "stop", -# "index": 0, -# "message": { -# "role": "assistant", -# "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic." -# } -# } -# ], -# "created": 1691429984.3852863, -# "model": "claude-instant-1", -# "usage": { -# "prompt_tokens": 18, -# "completion_tokens": 23, -# "total_tokens": 41 -# } -# } -# cache.add_cache(cache_key="test_key", result=sample_model_response_object_str, ttl=1) -# cached_value = cache.get_cache(cache_key="test_key") -# assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content'] -# time.sleep(2) -# assert cache.get_cache(cache_key="test_key") is None -# # test_in_memory_cache_with_ttl() diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 9f36df50b..093bf0f91 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -1994,6 +1994,7 @@ def test_completion_palm_stream(): def test_completion_together_ai_stream(): + litellm.set_verbose = True user_message = "Write 1pg about YC & litellm" messages = [{"content": user_message, "role": "user"}] try: diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index 266303df1..080754ca8 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -556,7 +556,6 @@ async def test_async_chat_bedrock_stream(): # asyncio.run(test_async_chat_bedrock_stream()) - ## Test Sagemaker + Async @pytest.mark.asyncio async def test_async_chat_sagemaker_stream(): @@ -725,7 +724,7 @@ async def test_async_embedding_bedrock(): response = await litellm.aembedding( model="bedrock/cohere.embed-multilingual-v3", input=["good morning from litellm"], - aws_region_name="os.environ/AWS_REGION_NAME_2", + aws_region_name="us-east-1", ) await asyncio.sleep(1) print(f"customHandler_success.errors: {customHandler_success.errors}") @@ -758,6 +757,7 @@ async def test_async_embedding_bedrock(): ## Test Azure - completion, embedding @pytest.mark.asyncio async def test_async_completion_azure_caching(): + litellm.set_verbose = True customHandler_caching = CompletionCustomHandler() litellm.cache = Cache( type="redis", @@ -812,6 +812,7 @@ async def test_async_embedding_azure_caching(): ) await asyncio.sleep(1) # success callbacks are done 
in parallel print(customHandler_caching.states) + print(customHandler_caching.errors) assert len(customHandler_caching.errors) == 0 assert len(customHandler_caching.states) == 4 # pre, post, success, success diff --git a/litellm/utils.py b/litellm/utils.py index fe899388f..bcc5dc7df 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -55,6 +55,7 @@ from .integrations.litedebugger import LiteDebugger from .proxy._types import KeyManagementSystem from openai import OpenAIError as OriginalError from openai._models import BaseModel as OpenAIObject +from .caching import S3Cache from .exceptions import ( AuthenticationError, BadRequestError, @@ -862,6 +863,7 @@ class Logging: curl_command += additional_args.get("request_str", None) elif api_base == "": curl_command = self.model_call_details + print_verbose(f"\033[92m{curl_command}\033[0m\n") verbose_logger.info(f"\033[92m{curl_command}\033[0m\n") if self.logger_fn and callable(self.logger_fn): try: @@ -2196,12 +2198,21 @@ def client(original_function): ) # if caching is false or cache["no-cache"]==True, don't run this if ( - (kwargs.get("caching", None) is None and litellm.cache is not None) - or kwargs.get("caching", False) == True - or ( - kwargs.get("cache", None) is not None - and kwargs.get("cache", {}).get("no-cache", False) != True + ( + ( + kwargs.get("caching", None) is None + and kwargs.get("cache", None) is None + and litellm.cache is not None + ) + or kwargs.get("caching", False) == True + or ( + kwargs.get("cache", None) is not None + and kwargs.get("cache", {}).get("no-cache", False) != True + ) ) + and kwargs.get("aembedding", False) != True + and kwargs.get("acompletion", False) != True + and kwargs.get("aimg_generation", False) != True ): # allow users to control returning cached responses from the completion function # checking cache print_verbose(f"INSIDE CHECKING CACHE") @@ -2435,6 +2446,7 @@ def client(original_function): result = None logging_obj = kwargs.get("litellm_logging_obj", None) # only set litellm_call_id if its not in kwargs + call_type = original_function.__name__ if "litellm_call_id" not in kwargs: kwargs["litellm_call_id"] = str(uuid.uuid4()) try: @@ -2465,8 +2477,14 @@ def client(original_function): f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}" ) # if caching is false, don't run this + final_embedding_cached_response = None + if ( - (kwargs.get("caching", None) is None and litellm.cache is not None) + ( + kwargs.get("caching", None) is None + and kwargs.get("cache", None) is None + and litellm.cache is not None + ) or kwargs.get("caching", False) == True or ( kwargs.get("cache", None) is not None @@ -2481,8 +2499,36 @@ def client(original_function): in litellm.cache.supported_call_types ): print_verbose(f"Checking Cache") - cached_result = litellm.cache.get_cache(*args, **kwargs) - if cached_result != None: + if call_type == CallTypes.aembedding.value and isinstance( + kwargs["input"], list + ): + tasks = [] + for idx, i in enumerate(kwargs["input"]): + preset_cache_key = litellm.cache.get_cache_key( + *args, **{**kwargs, "input": i} + ) + tasks.append( + litellm.cache.async_get_cache( + cache_key=preset_cache_key + ) + ) + cached_result = await asyncio.gather(*tasks) + ## check if cached result is None ## + if cached_result is not None and isinstance( + cached_result, list + ): + if len(cached_result) == 1 and cached_result[0] is None: + cached_result = None + else: + preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) + kwargs[ + "preset_cache_key" + ] = 
preset_cache_key # for streaming calls, we need to pass the preset_cache_key + cached_result = litellm.cache.get_cache(*args, **kwargs) + + if cached_result is not None and not isinstance( + cached_result, list + ): print_verbose(f"Cache Hit!") call_type = original_function.__name__ if call_type == CallTypes.acompletion.value and isinstance( @@ -2555,6 +2601,103 @@ def client(original_function): args=(cached_result, start_time, end_time, cache_hit), ).start() return cached_result + elif ( + call_type == CallTypes.aembedding.value + and cached_result is not None + and isinstance(cached_result, list) + and litellm.cache is not None + and not isinstance( + litellm.cache.cache, S3Cache + ) # s3 doesn't support bulk writing. Exclude. + ): + remaining_list = [] + non_null_list = [] + for idx, cr in enumerate(cached_result): + if cr is None: + remaining_list.append(kwargs["input"][idx]) + else: + non_null_list.append((idx, cr)) + original_kwargs_input = kwargs["input"] + kwargs["input"] = remaining_list + if len(non_null_list) > 0: + print_verbose( + f"EMBEDDING CACHE HIT! - {len(non_null_list)}" + ) + final_embedding_cached_response = EmbeddingResponse( + model=kwargs.get("model"), + data=[None] * len(original_kwargs_input), + ) + final_embedding_cached_response._hidden_params[ + "cache_hit" + ] = True + + for val in non_null_list: + idx, cr = val # (idx, cr) tuple + if cr is not None: + final_embedding_cached_response.data[idx] = cr + if len(remaining_list) == 0: + # LOG SUCCESS + cache_hit = True + end_time = datetime.datetime.now() + ( + model, + custom_llm_provider, + dynamic_api_key, + api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=kwargs.get( + "custom_llm_provider", None + ), + api_base=kwargs.get("api_base", None), + api_key=kwargs.get("api_key", None), + ) + print_verbose( + f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" + ) + logging_obj.update_environment_variables( + model=model, + user=kwargs.get("user", None), + optional_params={}, + litellm_params={ + "logger_fn": kwargs.get("logger_fn", None), + "acompletion": True, + "metadata": kwargs.get("metadata", {}), + "model_info": kwargs.get("model_info", {}), + "proxy_server_request": kwargs.get( + "proxy_server_request", None + ), + "preset_cache_key": kwargs.get( + "preset_cache_key", None + ), + "stream_response": kwargs.get( + "stream_response", {} + ), + }, + input=kwargs.get("messages", ""), + api_key=kwargs.get("api_key", None), + original_response=str(final_embedding_cached_response), + additional_args=None, + stream=kwargs.get("stream", False), + ) + asyncio.create_task( + logging_obj.async_success_handler( + final_embedding_cached_response, + start_time, + end_time, + cache_hit, + ) + ) + threading.Thread( + target=logging_obj.success_handler, + args=( + final_embedding_cached_response, + start_time, + end_time, + cache_hit, + ), + ).start() + return final_embedding_cached_response # MODEL CALL result = await original_function(*args, **kwargs) end_time = datetime.datetime.now() @@ -2587,12 +2730,28 @@ def client(original_function): if isinstance(result, litellm.ModelResponse) or isinstance( result, litellm.EmbeddingResponse ): - asyncio.create_task( - litellm.cache._async_add_cache(result.json(), *args, **kwargs) - ) + if ( + isinstance(result, EmbeddingResponse) + and isinstance(kwargs["input"], list) + and litellm.cache is not None + and not isinstance( + litellm.cache.cache, S3Cache + ) # s3 doesn't support bulk writing. Exclude. 
+ ): + asyncio.create_task( + litellm.cache.async_add_cache_pipeline( + result, *args, **kwargs + ) + ) + else: + asyncio.create_task( + litellm.cache.async_add_cache( + result.json(), *args, **kwargs + ) + ) else: asyncio.create_task( - litellm.cache._async_add_cache(result, *args, **kwargs) + litellm.cache.async_add_cache(result, *args, **kwargs) ) # LOG SUCCESS - handle streaming success logging in the _next_ object print_verbose( @@ -2616,6 +2775,27 @@ def client(original_function): result._response_ms = ( end_time - start_time ).total_seconds() * 1000 # return response latency in ms like openai + + if ( + isinstance(result, EmbeddingResponse) + and final_embedding_cached_response is not None + ): + idx = 0 + final_data_list = [] + for item in final_embedding_cached_response.data: + if item is None: + final_data_list.append(result.data[idx]) + idx += 1 + else: + final_data_list.append(item) + + final_embedding_cached_response.data = final_data_list + final_embedding_cached_response._hidden_params["cache_hit"] = True + final_embedding_cached_response._response_ms = ( + end_time - start_time + ).total_seconds() * 1000 + return final_embedding_cached_response + return result except Exception as e: traceback_exception = traceback.format_exc() @@ -3275,7 +3455,11 @@ def completion_cost( else: raise Exception(f"Model={model} not found in completion cost model map") # Calculate cost based on prompt_tokens, completion_tokens - if "togethercomputer" in model or "together_ai" in model: + if ( + "togethercomputer" in model + or "together_ai" in model + or custom_llm_provider == "together_ai" + ): # together ai prices based on size of llm # get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json model = get_model_params_and_category(model) @@ -3864,7 +4048,7 @@ def get_optional_params( _check_valid_arg(supported_params=supported_params) if stream: - optional_params["stream_tokens"] = stream + optional_params["stream"] = stream if temperature is not None: optional_params["temperature"] = temperature if top_p is not None: @@ -4498,6 +4682,14 @@ def get_llm_provider( # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1 api_base = "https://api.voyageai.com/v1" dynamic_api_key = get_secret("VOYAGE_API_KEY") + elif custom_llm_provider == "together_ai": + api_base = "https://api.together.xyz/v1" + dynamic_api_key = ( + get_secret("TOGETHER_API_KEY") + or get_secret("TOGETHER_AI_API_KEY") + or get_secret("TOGETHERAI_API_KEY") + or get_secret("TOGETHER_AI_TOKEN") + ) return model, custom_llm_provider, dynamic_api_key, api_base elif model.split("/", 1)[0] in litellm.provider_list: custom_llm_provider = model.split("/", 1)[0] @@ -6383,7 +6575,12 @@ def exception_type( message=f"BedrockException - {original_exception.message}", llm_provider="bedrock", model=model, - response=original_exception.response, + response=httpx.Response( + status_code=500, + request=httpx.Request( + method="POST", url="https://api.openai.com/v1/" + ), + ), ) elif original_exception.status_code == 401: exception_mapping_worked = True
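
Taken together, the changes above give litellm.cache an async surface (async_get_cache, async_add_cache, async_add_cache_pipeline, disconnect) backed by redis.asyncio clients drawn from a shared BlockingConnectionPool, and teach the async wrapper in utils.py to cache and look up embedding inputs item-by-item, merging cached rows with freshly computed ones. The sketch below is not part of the patch; it is a minimal usage example under stated assumptions (REDIS_HOST / REDIS_PORT / REDIS_PASSWORD set in the environment, a valid provider key, and a placeholder embedding model name), intended only to show how the new pieces are expected to fit together.

# Minimal sketch of the new async caching flow (assumptions noted above).
import asyncio
import os

import litellm
from litellm import aembedding
from litellm.caching import Cache


async def main():
    # RedisCache now builds async clients from a shared BlockingConnectionPool
    # (see _redis.get_redis_connection_pool / get_redis_async_client above).
    litellm.cache = Cache(
        type="redis",
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
    )

    # First call: each input item is written individually via
    # Cache.async_add_cache_pipeline (one Redis pipeline, one round trip).
    # "text-embedding-ada-002" is a placeholder model name.
    await aembedding(
        model="text-embedding-ada-002",
        input=["hello world", "goodbye world"],
        caching=True,
    )

    # Second call: "hello world" is served from Redis and only the new item
    # reaches the provider; the wrapper in utils.py merges cached and fresh
    # rows and marks the response as a cache hit.
    response = await aembedding(
        model="text-embedding-ada-002",
        input=["hello world", "another sentence"],
        caching=True,
    )
    print(response._hidden_params.get("cache_hit"))

    # New Cache.disconnect() hook, also called by the proxy's shutdown_event.
    await litellm.cache.disconnect()


if __name__ == "__main__":
    asyncio.run(main())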