fix(proxy_server.py): fix linting issues

Krrish Dholakia 2023-11-24 11:38:53 -08:00
parent d3d8b86eaa
commit 2e8d582a34
5 changed files with 44 additions and 32 deletions

View file

@@ -156,8 +156,10 @@ class DualCache(BaseCache):
             traceback.print_exc()
 
     def flush_cache(self):
-        self.redis_cache.flush_cache()
-        self.in_memory_cache.flush_cache()
+        if self.in_memory_cache is not None:
+            self.in_memory_cache.flush_cache()
+        if self.redis_cache is not None:
+            self.redis_cache.flush_cache()
 
 #### LiteLLM.Completion Cache ####
 class Cache:
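
Note: the guarded flush matters because a dual cache can be configured with only one backend (in-memory only, or Redis only). A minimal sketch of the pattern, using a simplified stand-in class rather than the real litellm implementation:

    class SimpleDualCache:
        """Stand-in with the same optional-backend shape as litellm's DualCache."""
        def __init__(self, in_memory_cache=None, redis_cache=None):
            self.in_memory_cache = in_memory_cache
            self.redis_cache = redis_cache

        def flush_cache(self):
            # Only flush the backends that were actually configured.
            if self.in_memory_cache is not None:
                self.in_memory_cache.flush_cache()
            if self.redis_cache is not None:
                self.redis_cache.flush_cache()

    # With both backends left as None, flush_cache() no longer raises AttributeError.
    SimpleDualCache().flush_cache()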

View file

@@ -28,7 +28,9 @@ from litellm.utils import (
     completion_with_fallbacks,
     get_llm_provider,
     get_api_key,
-    mock_completion_streaming_obj
+    mock_completion_streaming_obj,
+    convert_to_model_response_object,
+    token_counter
 )
 from .llms import (
     anthropic,
@@ -2145,7 +2147,7 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
     # # Update usage information if needed
     if messages:
-        response["usage"]["prompt_tokens"] = litellm.utils.token_counter(model=model, messages=messages)
-        response["usage"]["completion_tokens"] = litellm.utils.token_counter(model=model, text=combined_content)
+        response["usage"]["prompt_tokens"] = token_counter(model=model, messages=messages)
+        response["usage"]["completion_tokens"] = token_counter(model=model, text=combined_content)
         response["usage"]["total_tokens"] = response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
 
-    return litellm.utils.convert_to_model_response_object(response_object=response, model_response_object=litellm.ModelResponse())
+    return convert_to_model_response_object(response_object=response, model_response_object=litellm.ModelResponse())
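
Note: for reference, a small sketch of how the rebuilt usage block is computed with the now directly imported token_counter. The model name, messages, and combined content are illustrative; counts vary by tokenizer:

    from litellm.utils import token_counter

    messages = [{"role": "user", "content": "Hey, how's it going?"}]
    combined_content = "Doing well, thanks for asking!"

    prompt_tokens = token_counter(model="gpt-3.5-turbo", messages=messages)
    completion_tokens = token_counter(model="gpt-3.5-turbo", text=combined_content)
    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
    print(usage)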

View file

@@ -277,10 +277,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             if isinstance(litellm.success_callback, list):
                 import utils
                 print("setting litellm success callback to track cost")
-                if (utils.track_cost_callback) not in litellm.success_callback:
-                    litellm.success_callback.append(utils.track_cost_callback)
+                if (utils.track_cost_callback) not in litellm.success_callback: # type: ignore
+                    litellm.success_callback.append(utils.track_cost_callback) # type: ignore
             else:
-                litellm.success_callback = utils.track_cost_callback
+                litellm.success_callback = utils.track_cost_callback # type: ignore
     ### START REDIS QUEUE ###
     use_queue = general_settings.get("use_queue", False)
     celery_setup(use_queue=use_queue)
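
Note: the `# type: ignore` markers are needed because `utils` is imported dynamically at runtime, so the type checker cannot see `track_cost_callback`. A hedged sketch of the same registration pattern with a locally defined callback; the function name and signature here are illustrative, not the proxy's real implementation:

    import litellm

    def track_cost_callback(kwargs, completion_response, start_time, end_time):
        # Illustrative success-callback body: log which model completed.
        print("completion cost event:", kwargs.get("model"))

    if isinstance(litellm.success_callback, list):
        if track_cost_callback not in litellm.success_callback:
            litellm.success_callback.append(track_cost_callback)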
@@ -717,32 +717,40 @@ async def test_endpoint(request: Request):
 @router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
 async def async_queue_request(request: Request):
     global celery_fn, llm_model_list
-    body = await request.body()
-    body_str = body.decode()
-    try:
-        data = ast.literal_eval(body_str)
-    except:
-        data = json.loads(body_str)
-    data["model"] = (
-        general_settings.get("completion_model", None) # server default
-        or user_model # model name passed via cli args
-        or data["model"] # default passed in http request
-    )
-    data["llm_model_list"] = llm_model_list
-    print(f"data: {data}")
-    job = celery_fn.apply_async(kwargs=data)
-    return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
-    pass
+    if celery_fn is not None:
+        body = await request.body()
+        body_str = body.decode()
+        try:
+            data = ast.literal_eval(body_str)
+        except:
+            data = json.loads(body_str)
+        data["model"] = (
+            general_settings.get("completion_model", None) # server default
+            or user_model # model name passed via cli args
+            or data["model"] # default passed in http request
+        )
+        data["llm_model_list"] = llm_model_list
+        print(f"data: {data}")
+        job = celery_fn.apply_async(kwargs=data)
+        return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={"error": "Queue not initialized"},
+        )
 
 
 @router.get("/queue/response/{task_id}", dependencies=[Depends(user_api_key_auth)])
 async def async_queue_response(request: Request, task_id: str):
     global celery_app_conn, async_result
     try:
-        job = async_result(task_id, app=celery_app_conn)
-        if job.ready():
-            return {"status": "finished", "result": job.result}
-        else:
-            return {'status': 'queued'}
+        if celery_app_conn is not None and async_result is not None:
+            job = async_result(task_id, app=celery_app_conn)
+            if job.ready():
+                return {"status": "finished", "result": job.result}
+            else:
+                return {'status': 'queued'}
+        else:
+            raise Exception()
     except Exception as e:
         return {"status": "finished", "result": str(e)}

View file

@@ -38,9 +38,9 @@ model_list = [{ # list of model deployments
 router = Router(model_list=model_list,
                 redis_host=os.getenv("REDIS_HOST"),
                 redis_password=os.getenv("REDIS_PASSWORD"),
-                redis_port=int(os.getenv("REDIS_PORT")),
+                redis_port=int(os.getenv("REDIS_PORT")), # type: ignore
                 routing_strategy="simple-shuffle",
-                set_verbose=True,
+                set_verbose=False,
                 num_retries=1) # type: ignore
 
 kwargs = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}],}

View file

@@ -101,7 +101,7 @@ def test_async_fallbacks():
     asyncio.run(test_get_response())
 
-# test_async_fallbacks()
+test_async_fallbacks()
 
 
 def test_sync_context_window_fallbacks():
     try: