forked from phoenix/litellm-mirror
fix(openai-proxy/utils.py): adding caching
parent ea0c65d146
commit c34e9d73ff

5 changed files with 111 additions and 7 deletions

@@ -774,12 +774,13 @@ def client(original_function):
         if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
             litellm.cache = Cache()

-        if kwargs.get("caching", False): # allow users to control returning cached responses from the completion function
+        if kwargs.get("caching", False) or litellm.cache is not None: # allow users to control returning cached responses from the completion function
             # checking cache
             if (litellm.cache != None or litellm.caching or litellm.caching_with_models):
                 print_verbose(f"LiteLLM: Checking Cache")
                 cached_result = litellm.cache.get_cache(*args, **kwargs)
                 if cached_result != None:
+                    print_verbose(f"Cache Hit!")
                     return cached_result

         # MODEL CALL

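The widened condition means that once litellm.cache is set (for example by the proxy's set_callbacks()), cached responses are returned even when the caller never passes caching=True. A minimal sketch of that behavior, assuming the default in-memory Cache() shown in this hunk:

import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # once this is set, the client wrapper checks the cache on every call

messages = [{"role": "user", "content": "hello"}]
first = litellm.completion(model="gpt-3.5-turbo", messages=messages)   # calls the API and stores the result
second = litellm.completion(model="gpt-3.5-turbo", messages=messages)  # should be served from the cache by the new check
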
@@ -1,5 +1,15 @@
 OPENAI_API_KEY = ""
+
+HUGGINGFACE_API_KEY=""
+
+TOGETHERAI_API_KEY=""
+
+REPLICATE_API_KEY=""
+
+## bedrock / sagemaker
+AWS_ACCESS_KEY_ID = ""
+AWS_SECRET_ACCESS_KEY = ""
 
 AZURE_API_KEY = ""
 AZURE_API_BASE = ""
 AZURE_API_VERSION = ""

@@ -8,3 +18,17 @@ ANTHROPIC_API_KEY = ""
 
 COHERE_API_KEY = ""
 
+## LOGGING ##
+
+### LANGFUSE
+LANGFUSE_PUBLIC_KEY = ""
+LANGFUSE_SECRET_KEY = ""
+# Optional, defaults to https://cloud.langfuse.com
+LANGFUSE_HOST = "" # optional
+
+## CACHING ##
+
+### REDIS
+REDIS_HOST = ""
+REDIS_PORT = ""
+REDIS_PASSWORD = ""

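All three Redis values must be non-empty before the proxy enables Redis-backed caching (see the set_callbacks hunk at the end of this diff). An illustrative filled-in caching section, placeholder values only:

### REDIS (placeholder values, not real credentials)
REDIS_HOST = "my-redis.example.com"
REDIS_PORT = "6379"
REDIS_PASSWORD = "changeme"
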
@@ -1,4 +1,4 @@
-import litellm, os
+import litellm, os, traceback
 from fastapi import FastAPI, Request, HTTPException
 from fastapi.routing import APIRouter
 from fastapi.responses import StreamingResponse, FileResponse

@@ -21,7 +21,6 @@ app.add_middleware(
     allow_headers=["*"],
 )
 set_callbacks() # sets litellm callbacks for logging if they exist in the environment
 
 #### API ENDPOINTS ####
 @router.post("/v1/models")
 @router.get("/models") # if project requires model list

@@ -65,8 +64,10 @@ async def chat_completion(request: Request):
         data = await request.json()
         if "authorization" in request.headers: # if users pass LLM api keys as part of header
             api_key = request.headers.get("authorization")
-            api_key = api_key.split(" ")[1]
-            data["api_key"] = api_key
+            api_key = api_key.replace("Bearer", "").strip()
+            if len(api_key.strip()) > 0:
+                api_key = api_key
+                data["api_key"] = api_key
         response = litellm.completion(
             **data
         )

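One effect of the new parsing (an illustrative sketch, not part of the commit): the old split(" ")[1] raised an IndexError when a client sent the key without a "Bearer " prefix, while replace("Bearer", "").strip() handles both forms:

header_with_prefix = "Bearer sk-my-test-key"    # illustrative header values
header_without_prefix = "sk-my-test-key"

assert header_with_prefix.replace("Bearer", "").strip() == "sk-my-test-key"
assert header_without_prefix.replace("Bearer", "").strip() == "sk-my-test-key"
# header_without_prefix.split(" ")[1] would raise IndexError for the bare-key form
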
@@ -74,7 +75,10 @@ async def chat_completion(request: Request):
             return StreamingResponse(data_generator(response), media_type='text/event-stream')
         return response
     except Exception as e:
-        return HTTPException(status_code=500, detail=str(e))
+        error_traceback = traceback.format_exc()
+        error_msg = f"{str(e)}\n\n{error_traceback}"
+        return {"error": error_msg}
+        # raise HTTPException(status_code=500, detail=error_msg)
 
 @router.get("/")
 async def home(request: Request):

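With this change a failure no longer returns an HTTPException object as the response body; the caller receives a plain JSON object whose "error" field carries the message and traceback. A minimal sketch of that payload construction (variable names match the handler, the ValueError is a stand-in):

import traceback

try:
    raise ValueError("bad request")           # stand-in for a failing litellm.completion call
except Exception as e:
    error_traceback = traceback.format_exc()  # full stack trace, as in the handler
    error_msg = f"{str(e)}\n\n{error_traceback}"
    payload = {"error": error_msg}            # returned to the caller instead of raising

print(payload["error"].splitlines()[0])       # -> bad request
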
openai-proxy/tests/test_caching.py  (new file, 60 lines)
@@ -0,0 +1,60 @@
+import openai, os, dotenv, traceback, time
+openai.api_base = "http://0.0.0.0:8000"
+dotenv.load_dotenv()
+openai.api_key = os.getenv("ANTHROPIC_API_KEY") # this gets passed as a header
+
+response1 = openai.ChatCompletion.create(
+    model = "claude-instant-1",
+    messages = [
+        {
+            "role": "user",
+            "content": "this is a test message, what model / llm are you"
+        }
+    ],
+)
+
+try:
+    print(f"response: {response1['choices'][0]['message']['content']}")
+except:
+    print(f"response: {response1}")
+
+time.sleep(1) # allow time for request to be stored
+
+response2 = openai.ChatCompletion.create(
+    model = "claude-instant-1",
+    messages = [
+        {
+            "role": "user",
+            "content": "this is a test message, what model / llm are you"
+        }
+    ],
+)
+
+try:
+    print(f"response: {response2['choices'][0]['message']['content']}")
+except:
+    print(f"response: {response2}")
+
+openai.api_key = os.getenv("OPENAI_API_KEY")
+
+try:
+    response3 = openai.ChatCompletion.create(
+        model = "gpt-3.5-turbo",
+        messages = [
+            {
+                "role": "user",
+                "content": "this is a test message, what model / llm are you"
+            }
+        ],
+    )
+except Exception as e:
+    traceback.print_exc()
+
+try:
+    print(f"response: {response3['choices'][0]['message']['content']}")
+except:
+    print(f"response: {response3}")
+
+assert response1["choices"][0]["message"]["content"] == response2["choices"][0]["message"]["content"]
+
+assert response1["choices"][0]["message"]["content"] != response3["choices"][0]["message"]["content"]

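The test assumes the proxy is already running on http://0.0.0.0:8000 with caching configured: the two identical claude-instant-1 requests should come back with the same (cached) content, while the gpt-3.5-turbo request should differ, so both assertions hold only when caching is active and the cache key takes the model into account.
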
@@ -3,5 +3,20 @@ import dotenv
 dotenv.load_dotenv() # load env variables
 
 def set_callbacks():
-    if ("LANGFUSE_PUBLIC_KEY" in os.environ and "LANGFUSE_SECRET_KEY" in os.environ) or "LANGFUSE_HOST" in os.environ:
+    ## LOGGING
+    ### LANGFUSE
+    if (len(os.getenv("LANGFUSE_PUBLIC_KEY", "")) > 0 and len(os.getenv("LANGFUSE_SECRET_KEY", ""))) > 0 or len(os.getenv("LANGFUSE_HOST", "")) > 0:
+        print(f"sets langfuse integration")
         litellm.success_callback = ["langfuse"]
+
+    ## CACHING
+    ### REDIS
+    print(f"redis host: {len(os.getenv('REDIS_HOST', ''))}; redis port: {len(os.getenv('REDIS_PORT', ''))}; redis password: {len(os.getenv('REDIS_PASSWORD'))}")
+    if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0:
+        print(f"sets caching integration")
+        from litellm.caching import Cache
+        litellm.cache = Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))

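One caveat in this hunk (an observation, not a change in the commit): the debug print calls len(os.getenv('REDIS_PASSWORD')) without a default, so it raises a TypeError whenever REDIS_PASSWORD is unset. A guarded variant of the same checks, assuming a hypothetical _is_set helper that is not part of this commit:

import os

def _is_set(name: str) -> bool:
    # treat unset and empty-string environment variables the same way
    return len(os.getenv(name, "")) > 0

langfuse_enabled = (_is_set("LANGFUSE_PUBLIC_KEY") and _is_set("LANGFUSE_SECRET_KEY")) or _is_set("LANGFUSE_HOST")
redis_enabled = all(_is_set(name) for name in ("REDIS_HOST", "REDIS_PORT", "REDIS_PASSWORD"))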