fix(openai-proxy/utils.py): adding caching

Krrish Dholakia 2023-10-23 17:00:56 -07:00
parent ea0c65d146
commit c34e9d73ff
5 changed files with 111 additions and 7 deletions


@@ -774,12 +774,13 @@ def client(original_function):
             if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
                 litellm.cache = Cache()
 
-            if kwargs.get("caching", False): # allow users to control returning cached responses from the completion function
+            if kwargs.get("caching", False) or litellm.cache is not None: # allow users to control returning cached responses from the completion function
                 # checking cache
                 if (litellm.cache != None or litellm.caching or litellm.caching_with_models):
                     print_verbose(f"LiteLLM: Checking Cache")
                     cached_result = litellm.cache.get_cache(*args, **kwargs)
                     if cached_result != None:
+                        print_verbose(f"Cache Hit!")
                         return cached_result
 
             # MODEL CALL
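With this change, a configured cache is checked on every call: a cached result is returned whenever litellm.cache is set, not only when the caller passes caching=True. A minimal usage sketch of that behavior (model name and prompt are illustrative; assumes the default in-memory Cache from litellm.caching):

import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # after this commit, setting a cache is enough to enable lookups

messages = [{"role": "user", "content": "what model are you?"}]

# first call goes to the provider and is written to the cache
response1 = litellm.completion(model="gpt-3.5-turbo", messages=messages)

# an identical second call should be served from the cache (prints "Cache Hit!" when verbose logging is on)
response2 = litellm.completion(model="gpt-3.5-turbo", messages=messages)

assert response1["choices"][0]["message"]["content"] == response2["choices"][0]["message"]["content"]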


@@ -1,5 +1,15 @@
 OPENAI_API_KEY = ""
+HUGGINGFACE_API_KEY=""
+TOGETHERAI_API_KEY=""
+REPLICATE_API_KEY=""
+
+## bedrock / sagemaker
+AWS_ACCESS_KEY_ID = ""
+AWS_SECRET_ACCESS_KEY = ""
 
 AZURE_API_KEY = ""
 AZURE_API_BASE = ""
 AZURE_API_VERSION = ""
@@ -8,3 +18,17 @@ ANTHROPIC_API_KEY = ""
 COHERE_API_KEY = ""
+
+## LOGGING ##
+### LANGFUSE
+LANGFUSE_PUBLIC_KEY = ""
+LANGFUSE_SECRET_KEY = ""
+# Optional, defaults to https://cloud.langfuse.com
+LANGFUSE_HOST = "" # optional
+
+## CACHING ##
+### REDIS
+REDIS_HOST = ""
+REDIS_PORT = ""
+REDIS_PASSWORD = ""


@@ -1,4 +1,4 @@
-import litellm, os
+import litellm, os, traceback
 from fastapi import FastAPI, Request, HTTPException
 from fastapi.routing import APIRouter
 from fastapi.responses import StreamingResponse, FileResponse
@@ -21,7 +21,6 @@ app.add_middleware(
     allow_headers=["*"],
 )
 set_callbacks() # sets litellm callbacks for logging if they exist in the environment
-
 #### API ENDPOINTS ####
 @router.post("/v1/models")
 @router.get("/models") # if project requires model list
@@ -65,8 +64,10 @@ async def chat_completion(request: Request):
         data = await request.json()
         if "authorization" in request.headers: # if users pass LLM api keys as part of header
             api_key = request.headers.get("authorization")
-            api_key = api_key.split(" ")[1]
-            data["api_key"] = api_key
+            api_key = api_key.replace("Bearer", "").strip()
+            if len(api_key.strip()) > 0:
+                api_key = api_key
+                data["api_key"] = api_key
         response = litellm.completion(
             **data
         )
@@ -74,7 +75,10 @@ async def chat_completion(request: Request):
             return StreamingResponse(data_generator(response), media_type='text/event-stream')
         return response
     except Exception as e:
-        return HTTPException(status_code=500, detail=str(e))
+        error_traceback = traceback.format_exc()
+        error_msg = f"{str(e)}\n\n{error_traceback}"
+        return {"error": error_msg}
+        # raise HTTPException(status_code=500, detail=error_msg)
 
 @router.get("/")
 async def home(request: Request):
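The header handling above now strips a "Bearer" prefix instead of splitting on a space, and exceptions come back as a JSON body with an "error" key instead of a returned HTTPException. A quick client-side sketch against a locally running proxy (the URL, route, and model are assumptions; any HTTP client would do):

import os
import requests

resp = requests.post(
    "http://0.0.0.0:8000/chat/completions",
    headers={"Authorization": f"Bearer {os.environ['ANTHROPIC_API_KEY']}"},  # forwarded to litellm as data["api_key"]
    json={
        "model": "claude-instant-1",
        "messages": [{"role": "user", "content": "hi"}],
    },
)

body = resp.json()
if "error" in body:        # new error shape from the except branch
    print(body["error"])   # exception message plus full traceback
else:
    print(body["choices"][0]["message"]["content"])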


@@ -0,0 +1,60 @@
+import openai, os, dotenv, traceback, time
+
+openai.api_base = "http://0.0.0.0:8000"
+dotenv.load_dotenv()
+
+openai.api_key = os.getenv("ANTHROPIC_API_KEY") # this gets passed as a header
+
+response1 = openai.ChatCompletion.create(
+    model = "claude-instant-1",
+    messages = [
+        {
+            "role": "user",
+            "content": "this is a test message, what model / llm are you"
+        }
+    ],
+)
+
+try:
+    print(f"response: {response1['choices'][0]['message']['content']}")
+except:
+    print(f"response: {response1}")
+
+time.sleep(1) # allow time for request to be stored
+
+response2 = openai.ChatCompletion.create(
+    model = "claude-instant-1",
+    messages = [
+        {
+            "role": "user",
+            "content": "this is a test message, what model / llm are you"
+        }
+    ],
+)
+
+try:
+    print(f"response: {response2['choices'][0]['message']['content']}")
+except:
+    print(f"response: {response2}")
+
+openai.api_key = os.getenv("OPENAI_API_KEY")
+
+try:
+    response3 = openai.ChatCompletion.create(
+        model = "gpt-3.5-turbo",
+        messages = [
+            {
+                "role": "user",
+                "content": "this is a test message, what model / llm are you"
+            }
+        ],
+    )
+except Exception as e:
+    traceback.print_exc()
+
+try:
+    print(f"response: {response3['choices'][0]['message']['content']}")
+except:
+    print(f"response: {response3}")
+
+assert response1["choices"][0]["message"]["content"] == response2["choices"][0]["message"]["content"]
+assert response1["choices"][0]["message"]["content"] != response3["choices"][0]["message"]["content"]


@@ -3,5 +3,20 @@ import dotenv
 dotenv.load_dotenv() # load env variables
 
 def set_callbacks():
-    if ("LANGFUSE_PUBLIC_KEY" in os.environ and "LANGFUSE_SECRET_KEY" in os.environ) or "LANGFUSE_HOST" in os.environ:
+    ## LOGGING
+    ### LANGFUSE
+    if (len(os.getenv("LANGFUSE_PUBLIC_KEY", "")) > 0 and len(os.getenv("LANGFUSE_SECRET_KEY", ""))) > 0 or len(os.getenv("LANGFUSE_HOST", "")) > 0:
+        print(f"sets langfuse integration")
         litellm.success_callback = ["langfuse"]
+
+    ## CACHING
+    ### REDIS
+    print(f"redis host: {len(os.getenv('REDIS_HOST', ''))}; redis port: {len(os.getenv('REDIS_PORT', ''))}; redis password: {len(os.getenv('REDIS_PASSWORD'))}")
+    if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0:
+        print(f"sets caching integration")
+        from litellm.caching import Cache
+        litellm.cache = Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
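One caveat in the hunk above (left as committed): the debug print calls os.getenv('REDIS_PASSWORD') without a default, so len(...) raises a TypeError when that variable is unset. A more defensive sketch of the same check, purely for illustration and not part of the commit:

import os
import litellm

redis_host = os.getenv("REDIS_HOST", "")
redis_port = os.getenv("REDIS_PORT", "")
redis_password = os.getenv("REDIS_PASSWORD", "")  # empty default avoids len(None)

print(f"redis host set: {bool(redis_host)}; port set: {bool(redis_port)}; password set: {bool(redis_password)}")

if redis_host and redis_port and redis_password:
    from litellm.caching import Cache
    litellm.cache = Cache(type="redis", host=redis_host, port=redis_port, password=redis_password)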