Merge pull request #2461 from BerriAI/litellm_improve_mem_use
LiteLLM - improve memory utilization - don't use inMemCache on Router

Commit fa655d62fb
2 changed files with 149 additions and 14 deletions
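The hunks below drop Router.model_exception_map, a per-deployment list that appended every exception string for the lifetime of the process. A minimal sketch of that growth pattern, using hypothetical names rather than the Router's actual internals:

# Sketch of the unbounded-growth pattern removed below (hypothetical names, not LiteLLM code).
from collections import defaultdict

exception_map: dict = defaultdict(list)

def record_failure(deployment: str, err: Exception) -> None:
    # every failure retains another string; nothing is ever evicted,
    # so a long-running router with flaky deployments grows without bound
    exception_map[deployment].append(str(err))

for i in range(100_000):
    record_failure("azure/chatgpt-v-2", RuntimeError(f"rate limit hit on request {i}"))

print(sum(len(v) for v in exception_map.values()), "retained exception strings")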
litellm/router.py

@@ -210,9 +210,6 @@ class Router:
         self.context_window_fallbacks = (
             context_window_fallbacks or litellm.context_window_fallbacks
         )
-        self.model_exception_map: dict = (
-            {}
-        )  # dict to store model: list exceptions. self.exceptions = {"gpt-3.5": ["API KEY Error", "Rate Limit Error", "good morning error"]}
         self.total_calls: defaultdict = defaultdict(
             int
         )  # dict to store total calls made to each model
@@ -1524,17 +1521,6 @@ class Router:
                 self._set_cooldown_deployments(
                     deployment_id
                 )  # setting deployment_id in cooldown deployments
-            if metadata:
-                deployment = metadata.get("deployment", None)
-                deployment_exceptions = self.model_exception_map.get(deployment, [])
-                deployment_exceptions.append(exception_str)
-                self.model_exception_map[deployment] = deployment_exceptions
-                verbose_router_logger.debug("\nEXCEPTION FOR DEPLOYMENTS\n")
-                verbose_router_logger.debug(self.model_exception_map)
-                for model in self.model_exception_map:
-                    verbose_router_logger.debug(
-                        f"Model {model} had {len(self.model_exception_map[model])} exception"
-                    )
             if custom_llm_provider:
                 model_name = f"{custom_llm_provider}/{model_name}"
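If per-deployment exception history were still wanted, a bounded structure would cap the retained entries. This is a hypothetical alternative, not what this PR does (the PR removes the map entirely):

# Hypothetical bounded variant: keep only the most recent N exception strings per deployment.
from collections import defaultdict, deque

MAX_EXCEPTIONS_PER_DEPLOYMENT = 50  # illustrative cap, not a LiteLLM setting

bounded_exception_map = defaultdict(lambda: deque(maxlen=MAX_EXCEPTIONS_PER_DEPLOYMENT))

def record_failure(deployment: str, exception_str: str) -> None:
    # deque(maxlen=...) evicts the oldest entry once the cap is hit,
    # so memory stays constant no matter how many calls fail
    bounded_exception_map[deployment].append(exception_str)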
litellm/tests/test_mem_usage.py (new file, 149 lines)
@@ -0,0 +1,149 @@
# #### What this tests ####

# from memory_profiler import profile, memory_usage
# import sys, os, time
# import traceback, asyncio
# import pytest

# sys.path.insert(
#     0, os.path.abspath("../..")
# )  # Adds the parent directory to the system path
# import litellm
# from litellm import Router
# from concurrent.futures import ThreadPoolExecutor
# from collections import defaultdict
# from dotenv import load_dotenv
# import uuid
# import tracemalloc
# import objgraph

# objgraph.growth(shortnames=True)
# objgraph.show_most_common_types(limit=10)

# from mem_top import mem_top

# load_dotenv()


# model_list = [
#     {
#         "model_name": "gpt-3.5-turbo",  # openai model name
#         "litellm_params": {  # params for litellm completion/embedding call
#             "model": "azure/chatgpt-v-2",
#             "api_key": os.getenv("AZURE_API_KEY"),
#             "api_version": os.getenv("AZURE_API_VERSION"),
#             "api_base": os.getenv("AZURE_API_BASE"),
#         },
#         "tpm": 240000,
#         "rpm": 1800,
#     },
#     {
#         "model_name": "bad-model",  # openai model name
#         "litellm_params": {  # params for litellm completion/embedding call
#             "model": "azure/chatgpt-v-2",
#             "api_key": "bad-key",
#             "api_version": os.getenv("AZURE_API_VERSION"),
#             "api_base": os.getenv("AZURE_API_BASE"),
#         },
#         "tpm": 240000,
#         "rpm": 1800,
#     },
#     {
#         "model_name": "text-embedding-ada-002",
#         "litellm_params": {
#             "model": "azure/azure-embedding-model",
#             "api_key": os.environ["AZURE_API_KEY"],
#             "api_base": os.environ["AZURE_API_BASE"],
#         },
#         "tpm": 100000,
#         "rpm": 10000,
#     },
# ]
# litellm.set_verbose = True
# litellm.cache = litellm.Cache(
#     type="s3", s3_bucket_name="litellm-my-test-bucket-2", s3_region_name="us-east-1"
# )
# router = Router(
#     model_list=model_list,
#     fallbacks=[
#         {"bad-model": ["gpt-3.5-turbo"]},
#     ],
# )  # type: ignore


# async def router_acompletion():
#     # embedding call
#     question = f"This is a test: {uuid.uuid4()}" * 1

#     response = await router.acompletion(
#         model="bad-model", messages=[{"role": "user", "content": question}]
#     )
#     print("completion-resp", response)
#     return response


# async def main():
#     for i in range(1):
#         start = time.time()
#         n = 15  # Number of concurrent tasks
#         tasks = [router_acompletion() for _ in range(n)]

#         chat_completions = await asyncio.gather(*tasks)

#         successful_completions = [c for c in chat_completions if c is not None]

#         # Write errors to error_log.txt
#         with open("error_log.txt", "a") as error_log:
#             for completion in chat_completions:
#                 if isinstance(completion, str):
#                     error_log.write(completion + "\n")

#         print(n, time.time() - start, len(successful_completions))
#         print()
#         print(vars(router))


# if __name__ == "__main__":
#     # Blank out contents of error_log.txt
#     open("error_log.txt", "w").close()

#     import tracemalloc

#     tracemalloc.start(25)

#     # ... run your application ...

#     asyncio.run(main())
#     print(mem_top())

#     snapshot = tracemalloc.take_snapshot()
#     # top_stats = snapshot.statistics('lineno')

#     # print("[ Top 10 ]")
#     # for stat in top_stats[:50]:
#     #     print(stat)

#     top_stats = snapshot.statistics("traceback")

#     # pick the biggest memory block
#     stat = top_stats[0]
#     print("%s memory blocks: %.1f KiB" % (stat.count, stat.size / 1024))
#     for line in stat.traceback.format():
#         print(line)
#     print()
#     stat = top_stats[1]
#     print("%s memory blocks: %.1f KiB" % (stat.count, stat.size / 1024))
#     for line in stat.traceback.format():
#         print(line)

#     print()
#     stat = top_stats[2]
#     print("%s memory blocks: %.1f KiB" % (stat.count, stat.size / 1024))
#     for line in stat.traceback.format():
#         print(line)
#     print()

#     stat = top_stats[3]
#     print("%s memory blocks: %.1f KiB" % (stat.count, stat.size / 1024))
#     for line in stat.traceback.format():
#         print(line)