forked from phoenix/litellm-mirror
fix(proxy_server.py): fix linting issues
This commit is contained in: parent d3d8b86eaa, commit 2e8d582a34
5 changed files with 44 additions and 32 deletions
@@ -156,8 +156,10 @@ class DualCache(BaseCache):
             traceback.print_exc()

     def flush_cache(self):
-        self.redis_cache.flush_cache()
-        self.in_memory_cache.flush_cache()
+        if self.in_memory_cache is not None:
+            self.in_memory_cache.flush_cache()
+        if self.redis_cache is not None:
+            self.redis_cache.flush_cache()

 #### LiteLLM.Completion Cache ####
 class Cache:
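
The hunk above makes DualCache.flush_cache() check each backing cache for None before flushing, so a DualCache configured with only one layer (for example, in-memory caching without Redis) no longer raises on the missing one. Below is a minimal, self-contained sketch of that guard pattern; the names SimpleDualCache and InMemoryLayer are invented for illustration and are not litellm APIs.

# Hypothetical illustration of the None-guard pattern used in flush_cache();
# SimpleDualCache and InMemoryLayer are stand-ins, not litellm classes.
from typing import Optional


class InMemoryLayer:
    def __init__(self):
        self.store = {}

    def flush_cache(self):
        # Drop every cached entry held in memory.
        self.store.clear()


class SimpleDualCache:
    def __init__(self, in_memory_cache: Optional[InMemoryLayer] = None, redis_cache=None):
        self.in_memory_cache = in_memory_cache
        self.redis_cache = redis_cache

    def flush_cache(self):
        # Only flush layers that were actually configured, so a cache built
        # without Redis no longer hits AttributeError on a None attribute.
        if self.in_memory_cache is not None:
            self.in_memory_cache.flush_cache()
        if self.redis_cache is not None:
            self.redis_cache.flush_cache()


cache = SimpleDualCache(in_memory_cache=InMemoryLayer())  # no Redis configured
cache.flush_cache()  # safe: the missing Redis layer is simply skipped
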
@@ -28,7 +28,9 @@ from litellm.utils import (
     completion_with_fallbacks,
     get_llm_provider,
     get_api_key,
-    mock_completion_streaming_obj
+    mock_completion_streaming_obj,
+    convert_to_model_response_object,
+    token_counter
 )
 from .llms import (
     anthropic,
@@ -2145,7 +2147,7 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):

     # # Update usage information if needed
     if messages:
-        response["usage"]["prompt_tokens"] = litellm.utils.token_counter(model=model, messages=messages)
-        response["usage"]["completion_tokens"] = litellm.utils.token_counter(model=model, text=combined_content)
+        response["usage"]["prompt_tokens"] = token_counter(model=model, messages=messages)
+        response["usage"]["completion_tokens"] = token_counter(model=model, text=combined_content)
     response["usage"]["total_tokens"] = response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
-    return litellm.utils.convert_to_model_response_object(response_object=response, model_response_object=litellm.ModelResponse())
+    return convert_to_model_response_object(response_object=response, model_response_object=litellm.ModelResponse())
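
This hunk switches stream_chunk_builder to the token_counter and convert_to_model_response_object names added to the import list in the previous hunk, instead of reaching through litellm.utils at call time; the usage arithmetic is unchanged, with prompt and completion counts summed into total_tokens. A short sketch of that accounting, reusing only the call shapes visible in the diff (token_counter(model=..., messages=...) and token_counter(model=..., text=...)); the sample strings are placeholders.

# Sketch of the usage accounting performed at the end of stream_chunk_builder.
from litellm.utils import token_counter

model = "gpt-3.5-turbo"
messages = [{"role": "user", "content": "Write a haiku about caching."}]
combined_content = "Keys warm in memory, Redis remembers the rest."

usage = {}
usage["prompt_tokens"] = token_counter(model=model, messages=messages)
usage["completion_tokens"] = token_counter(model=model, text=combined_content)
usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]
print(usage)
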
@@ -277,10 +277,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         if isinstance(litellm.success_callback, list):
             import utils
             print("setting litellm success callback to track cost")
-            if (utils.track_cost_callback) not in litellm.success_callback:
-                litellm.success_callback.append(utils.track_cost_callback)
+            if (utils.track_cost_callback) not in litellm.success_callback: # type: ignore
+                litellm.success_callback.append(utils.track_cost_callback) # type: ignore
         else:
-            litellm.success_callback = utils.track_cost_callback
+            litellm.success_callback = utils.track_cost_callback # type: ignore
     ### START REDIS QUEUE ###
     use_queue = general_settings.get("use_queue", False)
     celery_setup(use_queue=use_queue)
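
The # type: ignore comments quiet the type checker on litellm.success_callback, whose declared element type does not line up with the utils.track_cost_callback function the proxy registers at runtime; the behavior is unchanged. A hedged sketch of the same append-if-missing registration with a stand-in callback (my_track_cost_callback and its body are invented for the example; the four-argument signature is the one litellm documents for custom success callbacks):

# Hypothetical callback registration mirroring the proxy's pattern;
# my_track_cost_callback is a stand-in, not the proxy's real cost tracker.
import litellm


def my_track_cost_callback(kwargs, completion_response, start_time, end_time):
    # Success callbacks receive the call kwargs, the completion response,
    # and start/end timestamps; this stub only logs the model name.
    print("completed call for model:", kwargs.get("model"))


if isinstance(litellm.success_callback, list):
    if my_track_cost_callback not in litellm.success_callback:
        litellm.success_callback.append(my_track_cost_callback)
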
@@ -717,32 +717,40 @@ async def test_endpoint(request: Request):
 @router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
 async def async_queue_request(request: Request):
     global celery_fn, llm_model_list
-    body = await request.body()
-    body_str = body.decode()
-    try:
-        data = ast.literal_eval(body_str)
-    except:
-        data = json.loads(body_str)
-    data["model"] = (
-        general_settings.get("completion_model", None) # server default
-        or user_model # model name passed via cli args
-        or data["model"] # default passed in http request
-    )
-    data["llm_model_list"] = llm_model_list
-    print(f"data: {data}")
-    job = celery_fn.apply_async(kwargs=data)
-    return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
-    pass
+    if celery_fn is not None:
+        body = await request.body()
+        body_str = body.decode()
+        try:
+            data = ast.literal_eval(body_str)
+        except:
+            data = json.loads(body_str)
+        data["model"] = (
+            general_settings.get("completion_model", None) # server default
+            or user_model # model name passed via cli args
+            or data["model"] # default passed in http request
+        )
+        data["llm_model_list"] = llm_model_list
+        print(f"data: {data}")
+        job = celery_fn.apply_async(kwargs=data)
+        return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={"error": "Queue not initialized"},
+        )

 @router.get("/queue/response/{task_id}", dependencies=[Depends(user_api_key_auth)])
 async def async_queue_response(request: Request, task_id: str):
     global celery_app_conn, async_result
     try:
-        job = async_result(task_id, app=celery_app_conn)
-        if job.ready():
-            return {"status": "finished", "result": job.result}
+        if celery_app_conn is not None and async_result is not None:
+            job = async_result(task_id, app=celery_app_conn)
+            if job.ready():
+                return {"status": "finished", "result": job.result}
+            else:
+                return {'status': 'queued'}
         else:
-            return {'status': 'queued'}
+            raise Exception()
     except Exception as e:
         return {"status": "finished", "result": str(e)}
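
With the guards above, /queue/request answers with HTTP 400 ("Queue not initialized") when celery_fn was never set up instead of failing on a None call, and /queue/response degrades through its existing except branch when the Celery connection is missing. From a caller's point of view, the contract visible in the diff is: POST /queue/request returns an id, a polling url, an eta, and a queued status; GET /queue/response/{id} reports queued or finished. A rough client-side polling sketch under those assumptions; the base URL and API key are placeholders, and the bearer-token header is an assumption about how user_api_key_auth is satisfied.

# Hedged client sketch against the queue endpoints shown in the diff.
# The base URL and API key below are placeholders, not real values.
import time
import requests

BASE_URL = "http://0.0.0.0:8000"                  # assumed proxy address
HEADERS = {"Authorization": "Bearer sk-placeholder"}

payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Hey, how's it going?"}],
}

# Enqueue the request; the response carries the task id and a polling URL.
resp = requests.post(f"{BASE_URL}/queue/request", json=payload, headers=HEADERS)
resp.raise_for_status()  # a 400 here means the queue is not initialized
job = resp.json()

# Poll until the Celery worker reports the job as finished.
while True:
    status = requests.get(f"{BASE_URL}{job['url']}", headers=HEADERS).json()
    if status["status"] == "finished":
        print(status["result"])
        break
    time.sleep(job.get("eta", 5))
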
@@ -38,9 +38,9 @@ model_list = [{ # list of model deployments
 router = Router(model_list=model_list,
                 redis_host=os.getenv("REDIS_HOST"),
                 redis_password=os.getenv("REDIS_PASSWORD"),
-                redis_port=int(os.getenv("REDIS_PORT")),
+                redis_port=int(os.getenv("REDIS_PORT")), # type: ignore
                 routing_strategy="simple-shuffle",
-                set_verbose=True,
+                set_verbose=False,
                 num_retries=1) # type: ignore
 kwargs = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}],}

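
The new # type: ignore on redis_port is needed because os.getenv returns Optional[str], so int(os.getenv("REDIS_PORT")) does not type-check even though the test assumes the variable is set. A small alternative sketch that satisfies the checker without the ignore by supplying an explicit default (the 6379 fallback is an assumption for illustration, not something the test does):

# Sketch: read REDIS_PORT without a type: ignore by giving os.getenv a
# string default, so int() always receives a str.
import os

redis_port = int(os.getenv("REDIS_PORT", "6379"))  # assumed fallback port
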
@@ -101,7 +101,7 @@ def test_async_fallbacks():

     asyncio.run(test_get_response())

-# test_async_fallbacks()
+test_async_fallbacks()

 def test_sync_context_window_fallbacks():
     try: