(fix) improve mem util

ishaan-jaff 2024-03-11 16:22:04 -07:00
parent a97e8a9029
commit b617263860
2 changed files with 149 additions and 14 deletions


@@ -210,9 +210,6 @@ class Router:
         self.context_window_fallbacks = (
             context_window_fallbacks or litellm.context_window_fallbacks
         )
-        self.model_exception_map: dict = (
-            {}
-        )  # dict to store model: list exceptions. self.exceptions = {"gpt-3.5": ["API KEY Error", "Rate Limit Error", "good morning error"]}
         self.total_calls: defaultdict = defaultdict(
             int
         )  # dict to store total calls made to each model
@@ -1487,17 +1484,6 @@ class Router:
             self._set_cooldown_deployments(
                 deployment_id
             )  # setting deployment_id in cooldown deployments
-            if metadata:
-                deployment = metadata.get("deployment", None)
-                deployment_exceptions = self.model_exception_map.get(deployment, [])
-                deployment_exceptions.append(exception_str)
-                self.model_exception_map[deployment] = deployment_exceptions
-                verbose_router_logger.debug("\nEXCEPTION FOR DEPLOYMENTS\n")
-                verbose_router_logger.debug(self.model_exception_map)
-                for model in self.model_exception_map:
-                    verbose_router_logger.debug(
-                        f"Model {model} had {len(self.model_exception_map[model])} exception"
-                    )
             if custom_llm_provider:
                 model_name = f"{custom_llm_provider}/{model_name}"
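
The deleted `model_exception_map` is the memory fix itself: every failed call appended another exception string to a per-deployment list that was never trimmed or cleared, so a long-running Router with failing deployments grew without bound. If per-deployment exception history were ever wanted again, a bounded container would avoid the leak. A minimal sketch (hypothetical, not part of this commit) using collections.deque with a fixed window:

from collections import defaultdict, deque

# Hypothetical bounded replacement for model_exception_map: a deque with
# maxlen drops its oldest entry on each append once full, so memory use
# stays constant no matter how long the router runs.
MAX_EXCEPTIONS_PER_DEPLOYMENT = 25  # assumed window size

model_exception_map: defaultdict = defaultdict(
    lambda: deque(maxlen=MAX_EXCEPTIONS_PER_DEPLOYMENT)
)

def record_exception(deployment: str, exception_str: str) -> None:
    model_exception_map[deployment].append(exception_str)

The commit's second file is a new standalone script that stress-tests the Router and inspects memory: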


@@ -0,0 +1,149 @@
#### What this tests ####
# Memory utilization of the Router under concurrent load

from memory_profiler import profile, memory_usage
import sys, os, time
import traceback, asyncio
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm import Router
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
import uuid
import tracemalloc

import objgraph

# record a baseline of object counts; later growth() calls report the delta
objgraph.growth(shortnames=True)
objgraph.show_most_common_types(limit=10)

from mem_top import mem_top

load_dotenv()
model_list = [
    {
        "model_name": "gpt-3.5-turbo",  # openai model name
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "azure/chatgpt-v-2",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
        "tpm": 240000,
        "rpm": 1800,
    },
    {
        "model_name": "bad-model",  # openai model name
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "azure/chatgpt-v-2",
            "api_key": "bad-key",
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
        "tpm": 240000,
        "rpm": 1800,
    },
    {
        "model_name": "text-embedding-ada-002",
        "litellm_params": {
            "model": "azure/azure-embedding-model",
            "api_key": os.environ["AZURE_API_KEY"],
            "api_base": os.environ["AZURE_API_BASE"],
        },
        "tpm": 100000,
        "rpm": 10000,
    },
]

litellm.set_verbose = True
litellm.cache = litellm.Cache(
    type="s3", s3_bucket_name="litellm-my-test-bucket-2", s3_region_name="us-east-1"
)
router = Router(
    model_list=model_list,
    fallbacks=[
        {"bad-model": ["gpt-3.5-turbo"]},
    ],
)  # type: ignore


async def router_acompletion():
    # completion call; the uuid makes each prompt unique, so the cache can't serve a hit
    question = f"This is a test: {uuid.uuid4()}"
    response = await router.acompletion(
        model="bad-model", messages=[{"role": "user", "content": question}]
    )
    print("completion-resp", response)
    return response


async def main():
    for i in range(1):
        start = time.time()
        n = 15  # Number of concurrent tasks
        tasks = [router_acompletion() for _ in range(n)]
        chat_completions = await asyncio.gather(*tasks)
        successful_completions = [c for c in chat_completions if c is not None]

        # Write errors to error_log.txt
        with open("error_log.txt", "a") as error_log:
            for completion in chat_completions:
                if isinstance(completion, str):
                    error_log.write(completion + "\n")
        print(n, time.time() - start, len(successful_completions))
        print()
        print(vars(router))


if __name__ == "__main__":
    # Blank out contents of error_log.txt
    open("error_log.txt", "w").close()

    # tracemalloc is already imported above; store up to 25 frames per allocation
    tracemalloc.start(25)

    # run the workload
    asyncio.run(main())
    print(mem_top())
    snapshot = tracemalloc.take_snapshot()

    # top_stats = snapshot.statistics('lineno')
    # print("[ Top 10 ]")
    # for stat in top_stats[:50]:
    #     print(stat)
    top_stats = snapshot.statistics("traceback")

    # print the four biggest memory blocks, along with the traceback
    # that allocated each of them
    for stat in top_stats[:4]:
        print("%s memory blocks: %.1f KiB" % (stat.count, stat.size / 1024))
        for line in stat.traceback.format():
            print(line)
        print()
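
Since the goal of the script is to find allocations that survive the workload, a natural extension (a sketch, not part of this commit) is to take a tracemalloc snapshot before and after the run and diff them with Snapshot.compare_to, which reports per-line allocation deltas:

import tracemalloc

tracemalloc.start(25)  # keep up to 25 frames per allocation

before = tracemalloc.take_snapshot()
data = [bytes(1000) for _ in range(10_000)]  # stand-in workload that allocates
after = tracemalloc.take_snapshot()

# entries with a positive size_diff are allocations made (and kept) during
# the workload; the list comes back sorted with the largest deltas first
for stat in after.compare_to(before, "lineno")[:10]:
    print(stat)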