fix(proxy_server.py): fix linting issues

Krrish Dholakia committed 2023-11-24 11:38:53 -08:00
parent d3d8b86eaa
commit 2e8d582a34
5 changed files with 44 additions and 32 deletions

View file

@@ -156,8 +156,10 @@ class DualCache(BaseCache):
             traceback.print_exc()
 
     def flush_cache(self):
-        self.redis_cache.flush_cache()
-        self.in_memory_cache.flush_cache()
+        if self.in_memory_cache is not None:
+            self.in_memory_cache.flush_cache()
+        if self.redis_cache is not None:
+            self.redis_cache.flush_cache()
 
 #### LiteLLM.Completion Cache ####
 class Cache:

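The flush_cache guards matter because a DualCache may be configured with only one of its two layers; flushing an absent backend would raise AttributeError on None. A minimal standalone sketch of the same guard pattern (toy classes for illustration, not the litellm API):

    from typing import Optional

    class InMemoryLayer:
        def __init__(self):
            self.store = {}

        def flush_cache(self):
            self.store.clear()

    class ToyDualCache:
        # Illustrative stand-in for DualCache: either layer may be absent.
        def __init__(self, in_memory_cache: Optional[InMemoryLayer] = None,
                     redis_cache: Optional[InMemoryLayer] = None):
            self.in_memory_cache = in_memory_cache
            self.redis_cache = redis_cache

        def flush_cache(self):
            # Calling flush on an unconfigured (None) backend would raise
            # AttributeError; the guards skip whichever layer is missing.
            if self.in_memory_cache is not None:
                self.in_memory_cache.flush_cache()
            if self.redis_cache is not None:
                self.redis_cache.flush_cache()

    # A cache with no redis layer no longer crashes on flush:
    ToyDualCache(in_memory_cache=InMemoryLayer()).flush_cache()
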
View file

@@ -28,7 +28,9 @@ from litellm.utils import (
     completion_with_fallbacks,
     get_llm_provider,
     get_api_key,
-    mock_completion_streaming_obj
+    mock_completion_streaming_obj,
+    convert_to_model_response_object,
+    token_counter
 )
 from .llms import (
     anthropic,
@@ -2145,7 +2147,7 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
 
     # # Update usage information if needed
     if messages:
-        response["usage"]["prompt_tokens"] = litellm.utils.token_counter(model=model, messages=messages)
-    response["usage"]["completion_tokens"] = litellm.utils.token_counter(model=model, text=combined_content)
+        response["usage"]["prompt_tokens"] = token_counter(model=model, messages=messages)
+    response["usage"]["completion_tokens"] = token_counter(model=model, text=combined_content)
     response["usage"]["total_tokens"] = response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
-    return litellm.utils.convert_to_model_response_object(response_object=response, model_response_object=litellm.ModelResponse())
+    return convert_to_model_response_object(response_object=response, model_response_object=litellm.ModelResponse())

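The stream_chunk_builder change only drops the litellm.utils. prefix in favor of the helpers imported directly above. A small sketch of the recomputed usage block on its own (assumes litellm is installed; the example messages and combined content are made up):

    from litellm.utils import token_counter

    # Inputs stream_chunk_builder already has at this point:
    model = "gpt-3.5-turbo"
    messages = [{"role": "user", "content": "Hey, how's it going?"}]
    combined_content = "Doing well, thanks for asking!"  # all streamed deltas joined

    usage = {
        "prompt_tokens": token_counter(model=model, messages=messages),
        "completion_tokens": token_counter(model=model, text=combined_content),
    }
    usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]
    print(usage)
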
View file

@@ -277,10 +277,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         if isinstance(litellm.success_callback, list):
             import utils
             print("setting litellm success callback to track cost")
-            if (utils.track_cost_callback) not in litellm.success_callback:
-                litellm.success_callback.append(utils.track_cost_callback)
+            if (utils.track_cost_callback) not in litellm.success_callback: # type: ignore
+                litellm.success_callback.append(utils.track_cost_callback) # type: ignore
         else:
-            litellm.success_callback = utils.track_cost_callback
+            litellm.success_callback = utils.track_cost_callback # type: ignore
     ### START REDIS QUEUE ###
     use_queue = general_settings.get("use_queue", False)
     celery_setup(use_queue=use_queue)
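
The # type: ignore hints are needed because utils.track_cost_callback is a plain function being mixed into litellm.success_callback. A hedged sketch of registering a custom success callback of this kind (the callback body is illustrative, not the proxy's actual track_cost_callback):

    import litellm

    def track_cost_callback(kwargs, completion_response, start_time, end_time):
        # Illustrative body only; the proxy's real utils.track_cost_callback does more.
        cost = litellm.completion_cost(completion_response=completion_response)
        print(f"model={kwargs.get('model')} cost={cost}")

    # litellm.success_callback holds strings (built-in integrations) and callables,
    # which is why appending a bare function draws the type-checker complaints above.
    if track_cost_callback not in litellm.success_callback:
        litellm.success_callback.append(track_cost_callback)
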
@@ -717,32 +717,40 @@ async def test_endpoint(request: Request):
 @router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
 async def async_queue_request(request: Request):
     global celery_fn, llm_model_list
-    body = await request.body()
-    body_str = body.decode()
-    try:
-        data = ast.literal_eval(body_str)
-    except:
-        data = json.loads(body_str)
-    data["model"] = (
-        general_settings.get("completion_model", None) # server default
-        or user_model # model name passed via cli args
-        or data["model"] # default passed in http request
-    )
-    data["llm_model_list"] = llm_model_list
-    print(f"data: {data}")
-    job = celery_fn.apply_async(kwargs=data)
-    return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
-    pass
+    if celery_fn is not None:
+        body = await request.body()
+        body_str = body.decode()
+        try:
+            data = ast.literal_eval(body_str)
+        except:
+            data = json.loads(body_str)
+        data["model"] = (
+            general_settings.get("completion_model", None) # server default
+            or user_model # model name passed via cli args
+            or data["model"] # default passed in http request
+        )
+        data["llm_model_list"] = llm_model_list
+        print(f"data: {data}")
+        job = celery_fn.apply_async(kwargs=data)
+        return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={"error": "Queue not initialized"},
+        )
 
 @router.get("/queue/response/{task_id}", dependencies=[Depends(user_api_key_auth)])
 async def async_queue_response(request: Request, task_id: str):
     global celery_app_conn, async_result
     try:
-        job = async_result(task_id, app=celery_app_conn)
-        if job.ready():
-            return {"status": "finished", "result": job.result}
-        else:
-            return {'status': 'queued'}
+        if celery_app_conn is not None and async_result is not None:
+            job = async_result(task_id, app=celery_app_conn)
+            if job.ready():
+                return {"status": "finished", "result": job.result}
+            else:
+                return {'status': 'queued'}
+        else:
+            raise Exception()
     except Exception as e:
         return {"status": "finished", "result": str(e)}

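With both queue endpoints now refusing requests when Celery was never initialized, a client round-trip against them looks roughly like this (a sketch only: the base URL and API key are placeholders, and the response shapes are taken from the handlers above):

    import time
    import requests

    BASE_URL = "http://0.0.0.0:8000"               # placeholder proxy address
    HEADERS = {"Authorization": "Bearer sk-1234"}   # placeholder key for user_api_key_auth

    # Enqueue a completion request
    resp = requests.post(
        f"{BASE_URL}/queue/request",
        json={"model": "gpt-3.5-turbo",
              "messages": [{"role": "user", "content": "Hey, how's it going?"}]},
        headers=HEADERS,
    )
    resp.raise_for_status()
    job = resp.json()  # {"id": ..., "url": "/queue/response/<id>", "eta": 5, "status": "queued"}

    # Poll the response endpoint until the Celery worker finishes
    while True:
        out = requests.get(f"{BASE_URL}{job['url']}", headers=HEADERS).json()
        if out["status"] == "finished":
            print(out["result"])
            break
        time.sleep(job.get("eta", 5))
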
View file

@@ -38,9 +38,9 @@ model_list = [{ # list of model deployments
 router = Router(model_list=model_list,
                 redis_host=os.getenv("REDIS_HOST"),
                 redis_password=os.getenv("REDIS_PASSWORD"),
-                redis_port=int(os.getenv("REDIS_PORT")),
+                redis_port=int(os.getenv("REDIS_PORT")), # type: ignore
                 routing_strategy="simple-shuffle",
-                set_verbose=True,
+                set_verbose=False,
                 num_retries=1) # type: ignore
 
 kwargs = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}],}

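The redis_port cast needs # type: ignore because os.getenv returns Optional[str], while the test assumes REDIS_PORT is set. One way to express the same intent without silencing the checker (the 6379 fallback is an assumption, not part of the test):

    import os

    # os.environ.get with a string default is typed as str (not Optional[str]),
    # so int() passes the type checker without a "# type: ignore".
    redis_port = int(os.environ.get("REDIS_PORT", "6379"))
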
View file

@@ -101,7 +101,7 @@ def test_async_fallbacks():
     asyncio.run(test_get_response())
 
-# test_async_fallbacks()
+test_async_fallbacks()
 
 def test_sync_context_window_fallbacks():
     try: