forked from phoenix/litellm-mirror

Merge pull request #1809 from BerriAI/litellm_embedding_caching_updates

Support caching individual items in embedding list (Async embedding only)

Commit 28df60b609: 13 changed files with 638 additions and 196 deletions
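As a quick orientation, here is a minimal usage sketch of the behavior this PR adds, written against the public litellm API; the model name, the inputs, and the assumption that provider credentials are set in the environment are all illustrative and not taken from this diff.

# Illustrative only: with litellm.cache set, repeated items in an embedding
# list should be served from cache on later async calls, while new items are
# still sent to the provider.
import asyncio

import litellm
from litellm import aembedding
from litellm.caching import Cache


async def main():
    litellm.cache = Cache()  # in-memory cache; Redis is configured the same way

    # First call: both inputs are embedded and each item is cached individually.
    await aembedding(
        model="text-embedding-ada-002",
        input=["hey how's it going", "I'm doing well"],
        caching=True,
    )

    # Second call: the repeated input is a cache hit; only the new one is computed.
    response = await aembedding(
        model="text-embedding-ada-002",
        input=["hey how's it going", "something new entirely"],
        caching=True,
    )
    print(response._hidden_params.get("cache_hit"))


asyncio.run(main())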
@@ -285,6 +285,7 @@ openai_compatible_endpoints: List = [
    "api.endpoints.anyscale.com/v1",
    "api.deepinfra.com/v1/openai",
    "api.mistral.ai/v1",
    "api.together.xyz/v1",
]

# this is maintained for Exception Mapping

@@ -294,6 +295,7 @@ openai_compatible_providers: List = [
    "deepinfra",
    "perplexity",
    "xinference",
    "together_ai",
]
@@ -11,6 +11,7 @@
import os
import inspect
import redis, litellm
import redis.asyncio as async_redis
from typing import List, Optional

@@ -67,7 +68,10 @@ def get_redis_url_from_environment():
    )


def get_redis_client(**env_overrides):
def _get_redis_client_logic(**env_overrides):
    """
    Common functionality across sync + async redis client implementations
    """
    ### check if "os.environ/<key-name>" passed in
    for k, v in env_overrides.items():
        if isinstance(v, str) and v.startswith("os.environ/"):

@@ -85,9 +89,33 @@ def get_redis_client(**env_overrides):
        redis_kwargs.pop("port", None)
        redis_kwargs.pop("db", None)
        redis_kwargs.pop("password", None)

        return redis.Redis.from_url(**redis_kwargs)
    elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
        raise ValueError("Either 'host' or 'url' must be specified for redis.")
    litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
    return redis_kwargs


def get_redis_client(**env_overrides):
    redis_kwargs = _get_redis_client_logic(**env_overrides)
    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
        return redis.Redis.from_url(**redis_kwargs)
    return redis.Redis(**redis_kwargs)


def get_redis_async_client(**env_overrides):
    redis_kwargs = _get_redis_client_logic(**env_overrides)
    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
        return async_redis.Redis.from_url(**redis_kwargs)
    return async_redis.Redis(
        socket_timeout=5,
        **redis_kwargs,
    )


def get_redis_connection_pool(**env_overrides):
    redis_kwargs = _get_redis_client_logic(**env_overrides)
    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
        return async_redis.BlockingConnectionPool.from_url(
            timeout=5, url=redis_kwargs["url"]
        )
    return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
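A hedged sketch of how these helpers are meant to be combined, mirroring only the call sites visible in this diff (RedisCache builds one connection pool up front and hands it to the async client factory). The host/port values and the pass-through of connection_pool are assumptions on my part, not documented behavior.

# Sketch under stated assumptions: a reachable Redis at localhost:6379.
import os

from litellm._redis import (
    get_redis_client,
    get_redis_async_client,
    get_redis_connection_pool,
)

redis_kwargs = {
    "host": os.environ.get("REDIS_HOST", "localhost"),
    "port": int(os.environ.get("REDIS_PORT", 6379)),
}

sync_client = get_redis_client(**redis_kwargs)    # redis.Redis
pool = get_redis_connection_pool(**redis_kwargs)  # redis.asyncio BlockingConnectionPool

# Created per call in RedisCache.init_async_client(); the pool is shared.
async_client = get_redis_async_client(connection_pool=pool, **redis_kwargs)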
@@ -8,7 +8,7 @@
# Thank you users! We ❤️ you! - Krrish & Ishaan

import litellm
import time, logging
import time, logging, asyncio
import json, traceback, ast, hashlib
from typing import Optional, Literal, List, Union, Any
from openai._models import BaseModel as OpenAIObject

@@ -28,9 +28,18 @@ class BaseCache:
    def set_cache(self, key, value, **kwargs):
        raise NotImplementedError

    async def async_set_cache(self, key, value, **kwargs):
        raise NotImplementedError

    def get_cache(self, key, **kwargs):
        raise NotImplementedError

    async def async_get_cache(self, key, **kwargs):
        raise NotImplementedError

    async def disconnect(self):
        raise NotImplementedError


class InMemoryCache(BaseCache):
    def __init__(self):

@@ -43,6 +52,16 @@ class InMemoryCache(BaseCache):
        if "ttl" in kwargs:
            self.ttl_dict[key] = time.time() + kwargs["ttl"]

    async def async_set_cache(self, key, value, **kwargs):
        self.set_cache(key=key, value=value, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, ttl=None):
        for cache_key, cache_value in cache_list:
            if ttl is not None:
                self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
            else:
                self.set_cache(key=cache_key, value=cache_value)

    def get_cache(self, key, **kwargs):
        if key in self.cache_dict:
            if key in self.ttl_dict:

@@ -57,21 +76,27 @@ class InMemoryCache(BaseCache):
            return cached_response
        return None

    async def async_get_cache(self, key, **kwargs):
        return self.get_cache(key=key, **kwargs)

    def flush_cache(self):
        self.cache_dict.clear()
        self.ttl_dict.clear()

    async def disconnect(self):
        pass

    def delete_cache(self, key):
        self.cache_dict.pop(key, None)
        self.ttl_dict.pop(key, None)


class RedisCache(BaseCache):
    def __init__(self, host=None, port=None, password=None, **kwargs):
        import redis

        # if users don't provider one, use the default litellm cache
        from ._redis import get_redis_client

    def __init__(self, host=None, port=None, password=None, **kwargs):
        from ._redis import get_redis_client, get_redis_connection_pool

        redis_kwargs = {}
        if host is not None:
@@ -82,18 +107,84 @@ class RedisCache(BaseCache):
            redis_kwargs["password"] = password

        redis_kwargs.update(kwargs)

        self.redis_client = get_redis_client(**redis_kwargs)
        self.redis_kwargs = redis_kwargs
        self.async_redis_conn_pool = get_redis_connection_pool()

    def init_async_client(self):
        from ._redis import get_redis_async_client

        return get_redis_async_client(
            connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
        )

    def set_cache(self, key, value, **kwargs):
        ttl = kwargs.get("ttl", None)
        print_verbose(f"Set Redis Cache: key: {key}\nValue {value}")
        print_verbose(f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
        try:
            self.redis_client.set(name=key, value=str(value), ex=ttl)
        except Exception as e:
            # NON blocking - notify users Redis is throwing an exception
            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)

    async def async_set_cache(self, key, value, **kwargs):
        _redis_client = self.init_async_client()
        async with _redis_client as redis_client:
            ttl = kwargs.get("ttl", None)
            print_verbose(
                f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
            )
            try:
                await redis_client.set(name=key, value=json.dumps(value), ex=ttl)
            except Exception as e:
                # NON blocking - notify users Redis is throwing an exception
                logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)

    async def async_set_cache_pipeline(self, cache_list, ttl=None):
        """
        Use Redis Pipelines for bulk write operations
        """
        _redis_client = self.init_async_client()
        try:
            async with _redis_client as redis_client:
                async with redis_client.pipeline(transaction=True) as pipe:
                    # Iterate through each key-value pair in the cache_list and set them in the pipeline.
                    for cache_key, cache_value in cache_list:
                        print_verbose(
                            f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
                        )
                        # Set the value with a TTL if it's provided.
                        if ttl is not None:
                            pipe.setex(cache_key, ttl, json.dumps(cache_value))
                        else:
                            pipe.set(cache_key, json.dumps(cache_value))
                    # Execute the pipeline and return the results.
                    results = await pipe.execute()

            print_verbose(f"pipeline results: {results}")
            # Optionally, you could process 'results' to make sure that all set operations were successful.
            return results
        except Exception as e:
            print_verbose(f"Error occurred in pipeline write - {str(e)}")
            # NON blocking - notify users Redis is throwing an exception
            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)

    def _get_cache_logic(self, cached_response: Any):
        """
        Common 'get_cache_logic' across sync + async redis client implementations
        """
        if cached_response is None:
            return cached_response
        # cached_response is in `b{} convert it to ModelResponse
        cached_response = cached_response.decode("utf-8")  # Convert bytes to string
        try:
            cached_response = json.loads(
                cached_response
            )  # Convert string to dictionary
        except:
            cached_response = ast.literal_eval(cached_response)
        return cached_response

    def get_cache(self, key, **kwargs):
        try:
            print_verbose(f"Get Redis Cache: key: {key}")
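For orientation, a small hedged sketch of driving the new pipeline helper directly, assuming a reachable Redis instance; the keys and values below are made up for illustration.

# Sketch: async_set_cache_pipeline takes (key, value) pairs and writes them in
# one round trip through a Redis pipeline, optionally with a shared TTL.
import asyncio

from litellm.caching import RedisCache


async def main():
    cache = RedisCache(host="localhost", port=6379)
    cache_list = [
        ("embedding-item-1", {"embedding": [0.1, 0.2], "index": 0}),
        ("embedding-item-2", {"embedding": [0.3, 0.4], "index": 1}),
    ]
    await cache.async_set_cache_pipeline(cache_list, ttl=60)
    print(await cache.async_get_cache("embedding-item-1"))


asyncio.run(main())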
@@ -101,18 +192,23 @@ class RedisCache(BaseCache):
            print_verbose(
                f"Got Redis Cache: key: {key}, cached_response {cached_response}"
            )
            if cached_response != None:
                # cached_response is in `b{} convert it to ModelResponse
                cached_response = cached_response.decode(
                    "utf-8"
                )  # Convert bytes to string
            return self._get_cache_logic(cached_response=cached_response)
        except Exception as e:
            # NON blocking - notify users Redis is throwing an exception
            traceback.print_exc()
            logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)

    async def async_get_cache(self, key, **kwargs):
        _redis_client = self.init_async_client()
        async with _redis_client as redis_client:
            try:
                cached_response = json.loads(
                    cached_response
                )  # Convert string to dictionary
            except:
                cached_response = ast.literal_eval(cached_response)
            return cached_response
                print_verbose(f"Get Redis Cache: key: {key}")
                cached_response = await redis_client.get(key)
                print_verbose(
                    f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
                )
                response = self._get_cache_logic(cached_response=cached_response)
                return response
            except Exception as e:
                # NON blocking - notify users Redis is throwing an exception
                traceback.print_exc()

@@ -121,10 +217,15 @@ class RedisCache(BaseCache):
    def flush_cache(self):
        self.redis_client.flushall()

    async def disconnect(self):
        pass

    def delete_cache(self, key):
        self.redis_client.delete(key)


class S3Cache(BaseCache):
    def __init__(
        self,

@@ -202,6 +303,9 @@ class S3Cache(BaseCache):
            # NON blocking - notify users S3 is throwing an exception
            print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")

    async def async_set_cache(self, key, value, **kwargs):
        self.set_cache(key=key, value=value, **kwargs)

    def get_cache(self, key, **kwargs):
        import boto3, botocore

@@ -244,6 +348,9 @@ class S3Cache(BaseCache):
            traceback.print_exc()
            print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}")

    async def async_get_cache(self, key, **kwargs):
        return self.get_cache(key=key, **kwargs)

    def flush_cache(self):
        pass

@@ -361,9 +468,9 @@ class Cache:
        """
        if type == "redis":
            self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
        if type == "local":
        elif type == "local":
            self.cache = InMemoryCache()
        if type == "s3":
        elif type == "s3":
            self.cache = S3Cache(
                s3_bucket_name=s3_bucket_name,
                s3_region_name=s3_region_name,
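The change from "if" to "elif" above matters because exactly one backend should be selected per Cache instance; here is a short hedged sketch of the three constructions Cache(type=...) dispatches on, with placeholder connection and bucket values.

# Placeholders only: shows the three backend selections handled by Cache(type=...).
import os

import litellm
from litellm.caching import Cache

# In-memory (default) cache
litellm.cache = Cache(type="local")

# Redis-backed cache
litellm.cache = Cache(
    type="redis",
    host=os.environ.get("REDIS_HOST", "localhost"),
    port=os.environ.get("REDIS_PORT", "6379"),
    password=os.environ.get("REDIS_PASSWORD"),
)

# S3-backed cache (bucket and region are placeholders)
litellm.cache = Cache(
    type="s3",
    s3_bucket_name="my-cache-bucket",
    s3_region_name="us-west-2",
)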
@@ -489,6 +596,45 @@ class Cache:
                }
                time.sleep(0.02)

    def _get_cache_logic(
        self,
        cached_result: Optional[Any],
        max_age: Optional[float],
    ):
        """
        Common get cache logic across sync + async implementations
        """
        # Check if a timestamp was stored with the cached response
        if (
            cached_result is not None
            and isinstance(cached_result, dict)
            and "timestamp" in cached_result
        ):
            timestamp = cached_result["timestamp"]
            current_time = time.time()

            # Calculate age of the cached response
            response_age = current_time - timestamp

            # Check if the cached response is older than the max-age
            if max_age is not None and response_age > max_age:
                return None  # Cached response is too old

            # If the response is fresh, or there's no max-age requirement, return the cached response
            # cached_response is in `b{} convert it to ModelResponse
            cached_response = cached_result.get("response")
            try:
                if isinstance(cached_response, dict):
                    pass
                else:
                    cached_response = json.loads(
                        cached_response  # type: ignore
                    )  # Convert string to dictionary
            except:
                cached_response = ast.literal_eval(cached_response)  # type: ignore
            return cached_response
        return cached_result

    def get_cache(self, *args, **kwargs):
        """
        Retrieves the cached result for the given arguments.

@@ -511,54 +657,40 @@ class Cache:
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = self.cache.get_cache(cache_key)
                # Check if a timestamp was stored with the cached response
                if (
                    cached_result is not None
                    and isinstance(cached_result, dict)
                    and "timestamp" in cached_result
                    and max_age is not None
                ):
                    timestamp = cached_result["timestamp"]
                    current_time = time.time()

                    # Calculate age of the cached response
                    response_age = current_time - timestamp

                    # Check if the cached response is older than the max-age
                    if response_age > max_age:
                        print_verbose(
                            f"Cached response for key {cache_key} is too old. Max-age: {max_age}s, Age: {response_age}s"
                return self._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
                        return None  # Cached response is too old

                # If the response is fresh, or there's no max-age requirement, return the cached response
                # cached_response is in `b{} convert it to ModelResponse
                cached_response = cached_result.get("response")
                try:
                    if isinstance(cached_response, dict):
                        pass
                    else:
                        cached_response = json.loads(
                            cached_response
                        )  # Convert string to dictionary
                except:
                    cached_response = ast.literal_eval(cached_response)
                return cached_response
            return cached_result
        except Exception as e:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

    def add_cache(self, result, *args, **kwargs):
    async def async_get_cache(self, *args, **kwargs):
        """
        Adds a result to the cache.
        Async get cache implementation.

        Args:
            *args: args to litellm.completion() or embedding()
            **kwargs: kwargs to litellm.completion() or embedding()
        Used for embedding calls in async wrapper
        """
        try:  # never block execution
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(*args, **kwargs)
            if cache_key is not None:
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = await self.cache.async_get_cache(cache_key)
                return self._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception as e:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

        Returns:
            None
    def _add_cache_logic(self, result, *args, **kwargs):
        """
        Common implementation across sync + async add_cache functions
        """
        try:
            if "cache_key" in kwargs:
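The s-max-age handling factored into _get_cache_logic above is what per-request cache controls feed into. A hedged example of passing them through completion(), based only on the kwargs this code reads; an OpenAI key in the environment is assumed.

# Sketch: a cached response older than 60 seconds is treated as a cache miss.
import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache()

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
    cache={"s-max-age": 60},  # also read as "s-maxage"
)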
@@ -577,14 +709,82 @@ class Cache:
                    if k == "ttl":
                        kwargs["ttl"] = v
                cached_data = {"timestamp": time.time(), "response": result}
                return cache_key, cached_data, kwargs
            else:
                raise Exception("cache key is None")
        except Exception as e:
            raise e

    def add_cache(self, result, *args, **kwargs):
        """
        Adds a result to the cache.

        Args:
            *args: args to litellm.completion() or embedding()
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            None
        """
        try:
            cache_key, cached_data, kwargs = self._add_cache_logic(
                result=result, *args, **kwargs
            )
            self.cache.set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
            traceback.print_exc()
            pass

    async def _async_add_cache(self, result, *args, **kwargs):
        self.add_cache(result, *args, **kwargs)

    async def async_add_cache(self, result, *args, **kwargs):
        """
        Async implementation of add_cache
        """
        try:
            cache_key, cached_data, kwargs = self._add_cache_logic(
                result=result, *args, **kwargs
            )
            await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
            traceback.print_exc()

    async def async_add_cache_pipeline(self, result, *args, **kwargs):
        """
        Async implementation of add_cache for Embedding calls

        Does a bulk write, to prevent using too many clients
        """
        try:
            cache_list = []
            for idx, i in enumerate(kwargs["input"]):
                preset_cache_key = litellm.cache.get_cache_key(
                    *args, **{**kwargs, "input": i}
                )
                kwargs["cache_key"] = preset_cache_key
                embedding_response = result.data[idx]
                cache_key, cached_data, kwargs = self._add_cache_logic(
                    result=embedding_response,
                    *args,
                    **kwargs,
                )
                cache_list.append((cache_key, cached_data))
            if hasattr(self.cache, "async_set_cache_pipeline"):
                await self.cache.async_set_cache_pipeline(cache_list=cache_list)
            else:
                tasks = []
                for val in cache_list:
                    tasks.append(
                        self.cache.async_set_cache(cache_key, cached_data, **kwargs)
                    )
                await asyncio.gather(*tasks)
        except Exception as e:
            print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
            traceback.print_exc()

    async def disconnect(self):
        if hasattr(self.cache, "disconnect"):
            await self.cache.disconnect()


def enable_cache(
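The important detail in async_add_cache_pipeline is that every item in the input list gets its own cache key, computed from the same kwargs with input swapped for that single item. A condensed sketch of that keying scheme, with illustrative model name and inputs:

# Sketch of the per-item keying used above: one cache key per embedding input,
# so a later request can hit the cache for some items and miss for others.
import litellm
from litellm.caching import Cache

litellm.cache = Cache()

kwargs = {"model": "text-embedding-ada-002", "input": ["hello", "world"]}

per_item_keys = [
    litellm.cache.get_cache_key(**{**kwargs, "input": item})
    for item in kwargs["input"]
]
print(per_item_keys)  # two distinct keys, one per input string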
@@ -440,8 +440,8 @@ class OpenAIChatCompletion(BaseLLM):
                input=data["messages"],
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                    "headers": {"Authorization": f"Bearer {openai_client.api_key}"},
                    "api_base": openai_client._base_url._uri_reference,
                    "acompletion": False,
                    "complete_input_dict": data,
                },

@@ -1,3 +1,7 @@
"""
Deprecated. We now do together ai calls via the openai client.
Reference: https://docs.together.ai/docs/openai-api-compatibility
"""
import os, types
import json
from enum import Enum
@@ -10,6 +10,7 @@
import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any, Literal, Union
from functools import partial

import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx

@@ -234,6 +235,9 @@ async def acompletion(
        "model_list": model_list,
        "acompletion": True,  # assuming this is a required parameter
    }
    _, custom_llm_provider, _, _ = get_llm_provider(
        model=model, api_base=completion_kwargs.get("base_url", None)
    )
    try:
        # Use a partial function to pass your keyword arguments
        func = partial(completion, **completion_kwargs, **kwargs)

@@ -245,7 +249,6 @@ async def acompletion(
        _, custom_llm_provider, _, _ = get_llm_provider(
            model=model, api_base=kwargs.get("api_base", None)
        )

        if (
            custom_llm_provider == "openai"
            or custom_llm_provider == "azure"

@@ -788,6 +791,7 @@ def completion(
            or custom_llm_provider == "anyscale"
            or custom_llm_provider == "mistral"
            or custom_llm_provider == "openai"
            or custom_llm_provider == "together_ai"
            or "ft:gpt-3.5-turbo" in model  # finetune gpt-3.5-turbo
        ):  # allow user to make an openai call with a custom base
            # note: if a user sets a custom base - we should ensure this works

@@ -1327,6 +1331,9 @@ def completion(
            or ("togethercomputer" in model)
            or (model in litellm.together_ai_models)
        ):
            """
            Deprecated. We now do together ai calls via the openai client - https://docs.together.ai/docs/openai-api-compatibility
            """
            custom_llm_provider = "together_ai"
            together_ai_key = (
                api_key
@@ -380,7 +380,7 @@ def run_server(
            import gunicorn.app.base
        except:
            raise ImportError(
                "Uvicorn, gunicorn needs to be imported. Run - `pip 'litellm[proxy]'`"
                "uvicorn, gunicorn needs to be imported. Run - `pip install 'litellm[proxy]'`"
            )

        if config is not None:

@@ -444,6 +444,7 @@ def run_server(
            )
        if port == 8000 and is_port_in_use(port):
            port = random.randint(1024, 49152)

        from litellm.proxy.proxy_server import app

        if run_gunicorn == False:

@@ -521,5 +522,6 @@ def run_server(
            ).run()  # Run gunicorn


if __name__ == "__main__":
    run_server()
@@ -7,6 +7,20 @@ import secrets, subprocess
import hashlib, uuid
import warnings
import importlib
import warnings


def showwarning(message, category, filename, lineno, file=None, line=None):
    traceback_info = f"{filename}:{lineno}: {category.__name__}: {message}\n"
    if file is not None:
        file.write(traceback_info)


warnings.showwarning = showwarning
warnings.filterwarnings("default", category=UserWarning)

# Your client code here


messages: list = []
sys.path.insert(

@@ -4053,9 +4067,12 @@ def _has_user_setup_sso():
async def shutdown_event():
    global prisma_client, master_key, user_custom_auth, user_custom_key_generate
    if prisma_client:
        verbose_proxy_logger.debug("Disconnecting from Prisma")
        await prisma_client.disconnect()

    if litellm.cache is not None:
        await litellm.cache.disconnect()
    ## RESET CUSTOM VARIABLES ##
    cleanup_router_config_variables()
@@ -21,10 +21,18 @@ def setup_and_teardown():
    import litellm

    importlib.reload(litellm)
    import asyncio

    loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(loop)
    print(litellm)
    # from litellm import Router, completion, aembedding, acompletion, embedding
    yield

    # Teardown code (executes after the yield point)
    loop.close()  # Close the loop created earlier
    asyncio.set_event_loop(None)  # Remove the reference to the loop


def pytest_collection_modifyitems(config, items):
    # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
@@ -11,10 +11,10 @@ sys.path.insert(
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion
from litellm import embedding, completion, aembedding
from litellm.caching import Cache
import random
import hashlib
import hashlib, asyncio

# litellm.set_verbose=True

@@ -106,10 +106,7 @@ def test_caching_with_cache_controls():
        )
        print(f"response1: {response1}")
        print(f"response2: {response2}")
        assert (
            response2["choices"][0]["message"]["content"]
            == response1["choices"][0]["message"]["content"]
        )
        assert response2["id"] == response1["id"]
    except Exception as e:
        print(f"error occurred: {traceback.format_exc()}")
        pytest.fail(f"Error occurred: {e}")

@@ -259,6 +256,84 @@ def test_embedding_caching_azure():
# test_embedding_caching_azure()


@pytest.mark.asyncio
async def test_embedding_caching_azure_individual_items():
    """
    Tests caching for individual items in an embedding list

    - Cache an item
    - call aembedding(..) with the item + 1 unique item
    - compare to a 2nd aembedding (...) with 2 unique items
    ```
    embedding_1 = ["hey how's it going", "I'm doing well"]
    embedding_val_1 = embedding(...)

    embedding_2 = ["hey how's it going", "I'm fine"]
    embedding_val_2 = embedding(...)

    assert embedding_val_1[0]["id"] == embedding_val_2[0]["id"]
    ```
    """
    litellm.cache = Cache()
    common_msg = f"hey how's it going {uuid.uuid4()}"
    common_msg_2 = f"hey how's it going {uuid.uuid4()}"
    embedding_1 = [common_msg]
    embedding_2 = [
        common_msg,
        f"I'm fine {uuid.uuid4()}",
    ]

    embedding_val_1 = await aembedding(
        model="azure/azure-embedding-model", input=embedding_1, caching=True
    )
    embedding_val_2 = await aembedding(
        model="azure/azure-embedding-model", input=embedding_2, caching=True
    )
    print(f"embedding_val_2._hidden_params: {embedding_val_2._hidden_params}")
    assert embedding_val_2._hidden_params["cache_hit"] == True


@pytest.mark.asyncio
async def test_redis_cache_basic():
    """
    Init redis client
    - write to client
    - read from client
    """
    litellm.set_verbose = False

    random_number = random.randint(
        1, 100000
    )  # add a random number to ensure it's always adding / reading from cache
    messages = [
        {"role": "user", "content": f"write a one sentence poem about: {random_number}"}
    ]
    litellm.cache = Cache(
        type="redis",
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
    )
    response1 = completion(
        model="gpt-3.5-turbo",
        messages=messages,
    )

    cache_key = litellm.cache.get_cache_key(
        model="gpt-3.5-turbo",
        messages=messages,
    )
    print(f"cache_key: {cache_key}")
    litellm.cache.add_cache(result=response1, cache_key=cache_key)
    print(f"cache key pre async get: {cache_key}")
    stored_val = await litellm.cache.async_get_cache(
        model="gpt-3.5-turbo",
        messages=messages,
    )
    print(f"stored_val: {stored_val}")
    assert stored_val["id"] == response1.id


def test_redis_cache_completion():
    litellm.set_verbose = False
@@ -406,7 +481,7 @@ def test_redis_cache_acompletion_stream():
    import asyncio

    try:
        litellm.set_verbose = True
        litellm.set_verbose = False
        random_word = generate_random_word()
        messages = [
            {

@@ -434,7 +509,6 @@ def test_redis_cache_acompletion_stream():
                stream=True,
            )
            async for chunk in response1:
                print(chunk)
                response_1_content += chunk.choices[0].delta.content or ""
            print(response_1_content)

@@ -452,7 +526,6 @@ def test_redis_cache_acompletion_stream():
                stream=True,
            )
            async for chunk in response2:
                print(chunk)
                response_2_content += chunk.choices[0].delta.content or ""
            print(response_2_content)

@@ -914,101 +987,3 @@ def test_cache_context_managers():


# test_cache_context_managers()

# test_custom_redis_cache_params()

# def test_redis_cache_with_ttl():
#     cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
#     sample_model_response_object_str = """{
#     "choices": [
#         {
#             "finish_reason": "stop",
#             "index": 0,
#             "message": {
#                 "role": "assistant",
#                 "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
#             }
#         }
#     ],
#     "created": 1691429984.3852863,
#     "model": "claude-instant-1",
#     "usage": {
#         "prompt_tokens": 18,
#         "completion_tokens": 23,
#         "total_tokens": 41
#     }
#     }"""
#     sample_model_response_object = {
#     "choices": [
#         {
#             "finish_reason": "stop",
#             "index": 0,
#             "message": {
#                 "role": "assistant",
#                 "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
#             }
#         }
#     ],
#     "created": 1691429984.3852863,
#     "model": "claude-instant-1",
#     "usage": {
#         "prompt_tokens": 18,
#         "completion_tokens": 23,
#         "total_tokens": 41
#     }
#     }
#     cache.add_cache(cache_key="test_key", result=sample_model_response_object_str, ttl=1)
#     cached_value = cache.get_cache(cache_key="test_key")
#     print(f"cached-value: {cached_value}")
#     assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content']
#     time.sleep(2)
#     assert cache.get_cache(cache_key="test_key") is None

# # test_redis_cache_with_ttl()

# def test_in_memory_cache_with_ttl():
#     cache = Cache(type="local")
#     sample_model_response_object_str = """{
#     "choices": [
#         {
#             "finish_reason": "stop",
#             "index": 0,
#             "message": {
#                 "role": "assistant",
#                 "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
#             }
#         }
#     ],
#     "created": 1691429984.3852863,
#     "model": "claude-instant-1",
#     "usage": {
#         "prompt_tokens": 18,
#         "completion_tokens": 23,
#         "total_tokens": 41
#     }
#     }"""
#     sample_model_response_object = {
#     "choices": [
#         {
#             "finish_reason": "stop",
#             "index": 0,
#             "message": {
#                 "role": "assistant",
#                 "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
#             }
#         }
#     ],
#     "created": 1691429984.3852863,
#     "model": "claude-instant-1",
#     "usage": {
#         "prompt_tokens": 18,
#         "completion_tokens": 23,
#         "total_tokens": 41
#     }
#     }
#     cache.add_cache(cache_key="test_key", result=sample_model_response_object_str, ttl=1)
#     cached_value = cache.get_cache(cache_key="test_key")
#     assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content']
#     time.sleep(2)
#     assert cache.get_cache(cache_key="test_key") is None
# # test_in_memory_cache_with_ttl()
@@ -1994,6 +1994,7 @@ def test_completion_palm_stream():


def test_completion_together_ai_stream():
    litellm.set_verbose = True
    user_message = "Write 1pg about YC & litellm"
    messages = [{"content": user_message, "role": "user"}]
    try:
@@ -556,7 +556,6 @@ async def test_async_chat_bedrock_stream():

# asyncio.run(test_async_chat_bedrock_stream())


## Test Sagemaker + Async
@pytest.mark.asyncio
async def test_async_chat_sagemaker_stream():

@@ -725,7 +724,7 @@ async def test_async_embedding_bedrock():
        response = await litellm.aembedding(
            model="bedrock/cohere.embed-multilingual-v3",
            input=["good morning from litellm"],
            aws_region_name="os.environ/AWS_REGION_NAME_2",
            aws_region_name="us-east-1",
        )
        await asyncio.sleep(1)
        print(f"customHandler_success.errors: {customHandler_success.errors}")

@@ -758,6 +757,7 @@ async def test_async_embedding_bedrock():
## Test Azure - completion, embedding
@pytest.mark.asyncio
async def test_async_completion_azure_caching():
    litellm.set_verbose = True
    customHandler_caching = CompletionCustomHandler()
    litellm.cache = Cache(
        type="redis",

@@ -812,6 +812,7 @@ async def test_async_embedding_azure_caching():
    )
    await asyncio.sleep(1)  # success callbacks are done in parallel
    print(customHandler_caching.states)
    print(customHandler_caching.errors)
    assert len(customHandler_caching.errors) == 0
    assert len(customHandler_caching.states) == 4  # pre, post, success, success
litellm/utils.py: 213 changed lines
@@ -55,6 +55,7 @@ from .integrations.litedebugger import LiteDebugger
from .proxy._types import KeyManagementSystem
from openai import OpenAIError as OriginalError
from openai._models import BaseModel as OpenAIObject
from .caching import S3Cache
from .exceptions import (
    AuthenticationError,
    BadRequestError,

@@ -862,6 +863,7 @@ class Logging:
                curl_command += additional_args.get("request_str", None)
            elif api_base == "":
                curl_command = self.model_call_details
            print_verbose(f"\033[92m{curl_command}\033[0m\n")
            verbose_logger.info(f"\033[92m{curl_command}\033[0m\n")
            if self.logger_fn and callable(self.logger_fn):
                try:
@@ -2196,12 +2198,21 @@ def client(original_function):
            )
            # if caching is false or cache["no-cache"]==True, don't run this
            if (
                (kwargs.get("caching", None) is None and litellm.cache is not None)
                (
                    (
                        kwargs.get("caching", None) is None
                        and kwargs.get("cache", None) is None
                        and litellm.cache is not None
                    )
                    or kwargs.get("caching", False) == True
                    or (
                        kwargs.get("cache", None) is not None
                        and kwargs.get("cache", {}).get("no-cache", False) != True
                    )
                )
                and kwargs.get("aembedding", False) != True
                and kwargs.get("acompletion", False) != True
                and kwargs.get("aimg_generation", False) != True
            ):  # allow users to control returning cached responses from the completion function
                # checking cache
                print_verbose(f"INSIDE CHECKING CACHE")
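The widened condition above lets a per-request cache dict override the global setting. A hedged example of opting one call out of the lookup via the no-cache key this code reads; an OpenAI key in the environment is assumed.

# Sketch: litellm.cache is set globally, but this single call bypasses the
# cache lookup because cache["no-cache"] is True.
import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache()

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
    cache={"no-cache": True},
)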
@@ -2435,6 +2446,7 @@ def client(original_function):
        result = None
        logging_obj = kwargs.get("litellm_logging_obj", None)
        # only set litellm_call_id if its not in kwargs
        call_type = original_function.__name__
        if "litellm_call_id" not in kwargs:
            kwargs["litellm_call_id"] = str(uuid.uuid4())
        try:

@@ -2465,8 +2477,14 @@ def client(original_function):
                f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}"
            )
            # if caching is false, don't run this
            final_embedding_cached_response = None

            if (
                (kwargs.get("caching", None) is None and litellm.cache is not None)
                (
                    kwargs.get("caching", None) is None
                    and kwargs.get("cache", None) is None
                    and litellm.cache is not None
                )
                or kwargs.get("caching", False) == True
                or (
                    kwargs.get("cache", None) is not None

@@ -2481,8 +2499,36 @@ def client(original_function):
                    in litellm.cache.supported_call_types
                ):
                    print_verbose(f"Checking Cache")
                    if call_type == CallTypes.aembedding.value and isinstance(
                        kwargs["input"], list
                    ):
                        tasks = []
                        for idx, i in enumerate(kwargs["input"]):
                            preset_cache_key = litellm.cache.get_cache_key(
                                *args, **{**kwargs, "input": i}
                            )
                            tasks.append(
                                litellm.cache.async_get_cache(
                                    cache_key=preset_cache_key
                                )
                            )
                        cached_result = await asyncio.gather(*tasks)
                        ## check if cached result is None ##
                        if cached_result is not None and isinstance(
                            cached_result, list
                        ):
                            if len(cached_result) == 1 and cached_result[0] is None:
                                cached_result = None
                    else:
                        preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                        kwargs[
                            "preset_cache_key"
                        ] = preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
                        cached_result = litellm.cache.get_cache(*args, **kwargs)
                    if cached_result != None:

                    if cached_result is not None and not isinstance(
                        cached_result, list
                    ):
                        print_verbose(f"Cache Hit!")
                        call_type = original_function.__name__
                        if call_type == CallTypes.acompletion.value and isinstance(
@@ -2555,6 +2601,103 @@ def client(original_function):
                            args=(cached_result, start_time, end_time, cache_hit),
                        ).start()
                        return cached_result
                    elif (
                        call_type == CallTypes.aembedding.value
                        and cached_result is not None
                        and isinstance(cached_result, list)
                        and litellm.cache is not None
                        and not isinstance(
                            litellm.cache.cache, S3Cache
                        )  # s3 doesn't support bulk writing. Exclude.
                    ):
                        remaining_list = []
                        non_null_list = []
                        for idx, cr in enumerate(cached_result):
                            if cr is None:
                                remaining_list.append(kwargs["input"][idx])
                            else:
                                non_null_list.append((idx, cr))
                        original_kwargs_input = kwargs["input"]
                        kwargs["input"] = remaining_list
                        if len(non_null_list) > 0:
                            print_verbose(
                                f"EMBEDDING CACHE HIT! - {len(non_null_list)}"
                            )
                            final_embedding_cached_response = EmbeddingResponse(
                                model=kwargs.get("model"),
                                data=[None] * len(original_kwargs_input),
                            )
                            final_embedding_cached_response._hidden_params[
                                "cache_hit"
                            ] = True

                            for val in non_null_list:
                                idx, cr = val  # (idx, cr) tuple
                                if cr is not None:
                                    final_embedding_cached_response.data[idx] = cr
                        if len(remaining_list) == 0:
                            # LOG SUCCESS
                            cache_hit = True
                            end_time = datetime.datetime.now()
                            (
                                model,
                                custom_llm_provider,
                                dynamic_api_key,
                                api_base,
                            ) = litellm.get_llm_provider(
                                model=model,
                                custom_llm_provider=kwargs.get(
                                    "custom_llm_provider", None
                                ),
                                api_base=kwargs.get("api_base", None),
                                api_key=kwargs.get("api_key", None),
                            )
                            print_verbose(
                                f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
                            )
                            logging_obj.update_environment_variables(
                                model=model,
                                user=kwargs.get("user", None),
                                optional_params={},
                                litellm_params={
                                    "logger_fn": kwargs.get("logger_fn", None),
                                    "acompletion": True,
                                    "metadata": kwargs.get("metadata", {}),
                                    "model_info": kwargs.get("model_info", {}),
                                    "proxy_server_request": kwargs.get(
                                        "proxy_server_request", None
                                    ),
                                    "preset_cache_key": kwargs.get(
                                        "preset_cache_key", None
                                    ),
                                    "stream_response": kwargs.get(
                                        "stream_response", {}
                                    ),
                                },
                                input=kwargs.get("messages", ""),
                                api_key=kwargs.get("api_key", None),
                                original_response=str(final_embedding_cached_response),
                                additional_args=None,
                                stream=kwargs.get("stream", False),
                            )
                            asyncio.create_task(
                                logging_obj.async_success_handler(
                                    final_embedding_cached_response,
                                    start_time,
                                    end_time,
                                    cache_hit,
                                )
                            )
                            threading.Thread(
                                target=logging_obj.success_handler,
                                args=(
                                    final_embedding_cached_response,
                                    start_time,
                                    end_time,
                                    cache_hit,
                                ),
                            ).start()
                            return final_embedding_cached_response
            # MODEL CALL
            result = await original_function(*args, **kwargs)
            end_time = datetime.datetime.now()
@@ -2586,13 +2729,29 @@ def client(original_function):
            ):
                if isinstance(result, litellm.ModelResponse) or isinstance(
                    result, litellm.EmbeddingResponse
                ):
                    if (
                        isinstance(result, EmbeddingResponse)
                        and isinstance(kwargs["input"], list)
                        and litellm.cache is not None
                        and not isinstance(
                            litellm.cache.cache, S3Cache
                        )  # s3 doesn't support bulk writing. Exclude.
                    ):
                        asyncio.create_task(
                            litellm.cache._async_add_cache(result.json(), *args, **kwargs)
                            litellm.cache.async_add_cache_pipeline(
                                result, *args, **kwargs
                            )
                        )
                    else:
                        asyncio.create_task(
                            litellm.cache._async_add_cache(result, *args, **kwargs)
                            litellm.cache.async_add_cache(
                                result.json(), *args, **kwargs
                            )
                        )
                else:
                    asyncio.create_task(
                        litellm.cache.async_add_cache(result, *args, **kwargs)
                    )
            # LOG SUCCESS - handle streaming success logging in the _next_ object
            print_verbose(

@@ -2616,6 +2775,27 @@ def client(original_function):
                result._response_ms = (
                    end_time - start_time
                ).total_seconds() * 1000  # return response latency in ms like openai

            if (
                isinstance(result, EmbeddingResponse)
                and final_embedding_cached_response is not None
            ):
                idx = 0
                final_data_list = []
                for item in final_embedding_cached_response.data:
                    if item is None:
                        final_data_list.append(result.data[idx])
                        idx += 1
                    else:
                        final_data_list.append(item)

                final_embedding_cached_response.data = final_data_list
                final_embedding_cached_response._hidden_params["cache_hit"] = True
                final_embedding_cached_response._response_ms = (
                    end_time - start_time
                ).total_seconds() * 1000
                return final_embedding_cached_response

            return result
        except Exception as e:
            traceback_exception = traceback.format_exc()
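To make the merge step above easier to follow, here is a standalone sketch of the same idea detached from litellm's response objects; plain lists stand in for EmbeddingResponse.data and all names are local to the sketch.

# Sketch of the merge: cached items were pre-placed at their original indices
# (None marks a miss); fresh results are consumed in order to fill the gaps.
from typing import Any, List, Optional


def merge_cached_and_fresh(
    cached_placeholders: List[Optional[Any]], fresh_results: List[Any]
) -> List[Any]:
    merged = []
    fresh_idx = 0
    for item in cached_placeholders:
        if item is None:
            merged.append(fresh_results[fresh_idx])
            fresh_idx += 1
        else:
            merged.append(item)
    return merged


# e.g. inputs 0 and 2 were cache hits, input 1 had to be recomputed
print(merge_cached_and_fresh(["emb_0", None, "emb_2"], ["fresh_emb_1"]))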
@@ -3275,7 +3455,11 @@ def completion_cost(
        else:
            raise Exception(f"Model={model} not found in completion cost model map")
        # Calculate cost based on prompt_tokens, completion_tokens
        if "togethercomputer" in model or "together_ai" in model:
        if (
            "togethercomputer" in model
            or "together_ai" in model
            or custom_llm_provider == "together_ai"
        ):
            # together ai prices based on size of llm
            # get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json
            model = get_model_params_and_category(model)

@@ -3864,7 +4048,7 @@ def get_optional_params(
        _check_valid_arg(supported_params=supported_params)

        if stream:
            optional_params["stream_tokens"] = stream
            optional_params["stream"] = stream
        if temperature is not None:
            optional_params["temperature"] = temperature
        if top_p is not None:
@@ -4498,6 +4682,14 @@ def get_llm_provider(
            # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
            api_base = "https://api.voyageai.com/v1"
            dynamic_api_key = get_secret("VOYAGE_API_KEY")
        elif custom_llm_provider == "together_ai":
            api_base = "https://api.together.xyz/v1"
            dynamic_api_key = (
                get_secret("TOGETHER_API_KEY")
                or get_secret("TOGETHER_AI_API_KEY")
                or get_secret("TOGETHERAI_API_KEY")
                or get_secret("TOGETHER_AI_TOKEN")
            )
        return model, custom_llm_provider, dynamic_api_key, api_base
    elif model.split("/", 1)[0] in litellm.provider_list:
        custom_llm_provider = model.split("/", 1)[0]

@@ -6383,7 +6575,12 @@ def exception_type(
                        message=f"BedrockException - {original_exception.message}",
                        llm_provider="bedrock",
                        model=model,
                        response=original_exception.response,
                        response=httpx.Response(
                            status_code=500,
                            request=httpx.Request(
                                method="POST", url="https://api.openai.com/v1/"
                            ),
                        ),
                    )
                elif original_exception.status_code == 401:
                    exception_mapping_worked = True
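Tying together the together_ai changes spread through this diff (OpenAI-compatible endpoint, key resolution in get_llm_provider, stream instead of stream_tokens), here is a hedged sketch of a call routed that way; the model name is a placeholder and TOGETHER_API_KEY is assumed to be exported.

# Placeholder model name; the key is resolved from TOGETHER_API_KEY (or the
# other env var names listed in get_llm_provider) and the request goes to
# https://api.together.xyz/v1 through the OpenAI-compatible client path.
import litellm

response = litellm.completion(
    model="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1",
    messages=[{"role": "user", "content": "Say hi in one sentence."}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")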