mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
fix(openai-proxy/utils.py): adding caching
This commit is contained in:
parent
ea0c65d146
commit
c34e9d73ff
5 changed files with 111 additions and 7 deletions
|
@ -774,12 +774,13 @@ def client(original_function):
|
|||
if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
|
||||
litellm.cache = Cache()
|
||||
|
||||
if kwargs.get("caching", False): # allow users to control returning cached responses from the completion function
|
||||
if kwargs.get("caching", False) or litellm.cache is not None: # allow users to control returning cached responses from the completion function
|
||||
# checking cache
|
||||
if (litellm.cache != None or litellm.caching or litellm.caching_with_models):
|
||||
print_verbose(f"LiteLLM: Checking Cache")
|
||||
cached_result = litellm.cache.get_cache(*args, **kwargs)
|
||||
if cached_result != None:
|
||||
print_verbose(f"Cache Hit!")
|
||||
return cached_result
|
||||
|
||||
# MODEL CALL
|
||||
|
|
|
@ -1,5 +1,15 @@
|
|||
OPENAI_API_KEY = ""
|
||||
|
||||
HUGGINGFACE_API_KEY=""
|
||||
|
||||
TOGETHERAI_API_KEY=""
|
||||
|
||||
REPLICATE_API_KEY=""
|
||||
|
||||
## bedrock / sagemaker
|
||||
AWS_ACCESS_KEY_ID = ""
|
||||
AWS_SECRET_ACCESS_KEY = ""
|
||||
|
||||
AZURE_API_KEY = ""
|
||||
AZURE_API_BASE = ""
|
||||
AZURE_API_VERSION = ""
|
||||
|
@ -8,3 +18,17 @@ ANTHROPIC_API_KEY = ""
|
|||
|
||||
COHERE_API_KEY = ""
|
||||
|
||||
## LOGGING ##
|
||||
|
||||
### LANGFUSE
|
||||
LANGFUSE_PUBLIC_KEY = ""
|
||||
LANGFUSE_SECRET_KEY = ""
|
||||
# Optional, defaults to https://cloud.langfuse.com
|
||||
LANGFUSE_HOST = "" # optional
|
||||
|
||||
## CACHING ##
|
||||
|
||||
### REDIS
|
||||
REDIS_HOST = ""
|
||||
REDIS_PORT = ""
|
||||
REDIS_PASSWORD = ""
|
|
@ -1,4 +1,4 @@
|
|||
import litellm, os
|
||||
import litellm, os, traceback
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
|
@ -21,7 +21,6 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
set_callbacks() # sets litellm callbacks for logging if they exist in the environment
|
||||
|
||||
#### API ENDPOINTS ####
|
||||
@router.post("/v1/models")
|
||||
@router.get("/models") # if project requires model list
|
||||
|
@ -65,8 +64,10 @@ async def chat_completion(request: Request):
|
|||
data = await request.json()
|
||||
if "authorization" in request.headers: # if users pass LLM api keys as part of header
|
||||
api_key = request.headers.get("authorization")
|
||||
api_key = api_key.split(" ")[1]
|
||||
data["api_key"] = api_key
|
||||
api_key = api_key.replace("Bearer", "").strip()
|
||||
if len(api_key.strip()) > 0:
|
||||
api_key = api_key
|
||||
data["api_key"] = api_key
|
||||
response = litellm.completion(
|
||||
**data
|
||||
)
|
||||
|
@ -74,7 +75,10 @@ async def chat_completion(request: Request):
|
|||
return StreamingResponse(data_generator(response), media_type='text/event-stream')
|
||||
return response
|
||||
except Exception as e:
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
return {"error": error_msg}
|
||||
# raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
@router.get("/")
|
||||
async def home(request: Request):
|
||||
|
|
60
openai-proxy/tests/test_caching.py
Normal file
60
openai-proxy/tests/test_caching.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
import openai, os, dotenv, traceback, time
|
||||
openai.api_base = "http://0.0.0.0:8000"
|
||||
dotenv.load_dotenv()
|
||||
openai.api_key = os.getenv("ANTHROPIC_API_KEY") # this gets passed as a header
|
||||
|
||||
response1 = openai.ChatCompletion.create(
|
||||
model = "claude-instant-1",
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test message, what model / llm are you"
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
print(f"response: {response1['choices'][0]['message']['content']}")
|
||||
except:
|
||||
print(f"response: {response1}")
|
||||
|
||||
time.sleep(1) # allow time for request to be stored
|
||||
|
||||
response2 = openai.ChatCompletion.create(
|
||||
model = "claude-instant-1",
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test message, what model / llm are you"
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
print(f"response: {response2['choices'][0]['message']['content']}")
|
||||
except:
|
||||
print(f"response: {response2}")
|
||||
|
||||
openai.api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
try:
|
||||
response3 = openai.ChatCompletion.create(
|
||||
model = "gpt-3.5-turbo",
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test message, what model / llm are you"
|
||||
}
|
||||
],
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
try:
|
||||
print(f"response: {response3['choices'][0]['message']['content']}")
|
||||
except:
|
||||
print(f"response: {response3}")
|
||||
|
||||
assert response1["choices"][0]["message"]["content"] == response2["choices"][0]["message"]["content"]
|
||||
|
||||
assert response1["choices"][0]["message"]["content"] != response3["choices"][0]["message"]["content"]
|
|
@ -3,5 +3,20 @@ import dotenv
|
|||
dotenv.load_dotenv() # load env variables
|
||||
|
||||
def set_callbacks():
|
||||
if ("LANGFUSE_PUBLIC_KEY" in os.environ and "LANGFUSE_SECRET_KEY" in os.environ) or "LANGFUSE_HOST" in os.environ:
|
||||
## LOGGING
|
||||
### LANGFUSE
|
||||
if (len(os.getenv("LANGFUSE_PUBLIC_KEY", "")) > 0 and len(os.getenv("LANGFUSE_SECRET_KEY", ""))) > 0 or len(os.getenv("LANGFUSE_HOST", "")) > 0:
|
||||
print(f"sets langfuse integration")
|
||||
litellm.success_callback = ["langfuse"]
|
||||
|
||||
## CACHING
|
||||
### REDIS
|
||||
print(f"redis host: {len(os.getenv('REDIS_HOST', ''))}; redis port: {len(os.getenv('REDIS_PORT', ''))}; redis password: {len(os.getenv('REDIS_PASSWORD'))}")
|
||||
if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0:
|
||||
print(f"sets caching integration")
|
||||
from litellm.caching import Cache
|
||||
litellm.cache = Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue