mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(proxy_server.py): hash keys
This commit is contained in:
parent
a4c9e18eb5
commit
6b1b1b82cf
5 changed files with 135 additions and 132 deletions
|
@ -1,71 +1,87 @@
|
|||
import litellm
|
||||
from litellm import ModelResponse
|
||||
from proxy_server import llm_model_list
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Any
|
||||
import os, subprocess, hashlib
|
||||
|
||||
def track_cost_callback(
|
||||
kwargs, # kwargs to completion
|
||||
completion_response: ModelResponse, # response from completion
|
||||
start_time = None,
|
||||
end_time = None, # start/end time for completion
|
||||
):
|
||||
try:
|
||||
# init logging config
|
||||
print("in custom callback tracking cost", llm_model_list)
|
||||
if "azure" in kwargs["model"]:
|
||||
# for azure cost tracking, we check the provided model list in the config.yaml
|
||||
# we need to map azure/chatgpt-deployment to -> azure/gpt-3.5-turbo
|
||||
pass
|
||||
# check if it has collected an entire stream response
|
||||
if "complete_streaming_response" in kwargs:
|
||||
# for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
|
||||
completion_response=kwargs["complete_streaming_response"]
|
||||
input_text = kwargs["messages"]
|
||||
output_text = completion_response["choices"][0]["message"]["content"]
|
||||
response_cost = litellm.completion_cost(
|
||||
model = kwargs["model"],
|
||||
messages = input_text,
|
||||
completion=output_text
|
||||
class PrismaClient:
|
||||
def __init__(self, database_url: str):
|
||||
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
|
||||
os.environ["DATABASE_URL"] = database_url
|
||||
subprocess.run(['prisma', 'generate'])
|
||||
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
|
||||
# Now you can import the Prisma Client
|
||||
from prisma import Client
|
||||
self.db = Client() #Client to connect to Prisma db
|
||||
|
||||
def hash_token(self, token: str):
|
||||
# Hash the string using SHA-256
|
||||
hashed_token = hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
return hashed_token
|
||||
|
||||
async def get_data(self, token: str, expires: Optional[Any]=None):
|
||||
hashed_token = self.hash_token(token=token)
|
||||
if expires:
|
||||
response = await self.db.litellm_verificationtoken.find_first(
|
||||
where={
|
||||
"token": hashed_token,
|
||||
"expires": {"gte": expires} # Check if the token is not expired
|
||||
}
|
||||
)
|
||||
else:
|
||||
response = await self.db.litellm_verificationtoken.find_unique(
|
||||
where={
|
||||
"token": hashed_token
|
||||
}
|
||||
)
|
||||
print("streaming response_cost", response_cost)
|
||||
# for non streaming responses
|
||||
else:
|
||||
# we pass the completion_response obj
|
||||
if kwargs["stream"] != True:
|
||||
input_text = kwargs.get("messages", "")
|
||||
if isinstance(input_text, list):
|
||||
input_text = "".join(m["content"] for m in input_text)
|
||||
response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
|
||||
print("regular response_cost", response_cost)
|
||||
except:
|
||||
pass
|
||||
return response
|
||||
|
||||
def update_prisma_database(token, response_cost):
|
||||
try:
|
||||
# Import your Prisma client
|
||||
from your_prisma_module import prisma
|
||||
async def insert_data(self, data: dict):
|
||||
"""
|
||||
Add a key to the database. If it already exists, do nothing.
|
||||
"""
|
||||
token = data["token"]
|
||||
hashed_token = self.hash_token(token=token)
|
||||
data["token"] = hashed_token
|
||||
print(f"passed in data: {data}; hashed_token: {hashed_token}")
|
||||
|
||||
# Fetch the existing cost for the given token
|
||||
existing_cost = prisma.LiteLLM_VerificationToken.find_unique(
|
||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||
where={
|
||||
"token": token
|
||||
}
|
||||
).cost
|
||||
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
new_cost = existing_cost + response_cost
|
||||
|
||||
# Update the cost column for the given token
|
||||
prisma_liteLLM_VerificationToken = prisma.LiteLLM_VerificationToken.update(
|
||||
where={
|
||||
"token": token
|
||||
'token': hashed_token,
|
||||
},
|
||||
data={
|
||||
"cost": new_cost
|
||||
"create": {**data}, #type: ignore
|
||||
"update": {} # don't do anything if it already exists
|
||||
}
|
||||
)
|
||||
print(f"Prisma database updated for token {token}. New cost: {new_cost}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error updating Prisma database: {e}")
|
||||
pass
|
||||
return new_verification_token
|
||||
|
||||
async def update_data(self, token: str, data: dict):
|
||||
"""
|
||||
Update existing data
|
||||
"""
|
||||
hashed_token = self.hash_token(token=token)
|
||||
data["token"] = hashed_token
|
||||
await self.db.litellm_verificationtoken.update(
|
||||
where={
|
||||
"token": hashed_token
|
||||
},
|
||||
data={**data}
|
||||
)
|
||||
return {"token": token, "data": data}
|
||||
|
||||
async def delete_data(self, tokens: List):
|
||||
"""
|
||||
Allow user to delete a key(s)
|
||||
"""
|
||||
hashed_tokens = [self.hash_token(token=token) for token in tokens]
|
||||
await self.db.litellm_verificationtoken.delete_many(
|
||||
where={"token": {"in": hashed_tokens}}
|
||||
)
|
||||
return {"deleted_keys": tokens}
|
||||
|
||||
async def connect(self):
|
||||
await self.db.connect()
|
||||
|
||||
async def disconnect(self):
|
||||
await self.db.disconnect()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue