(feat) proxy use ProxyException clas

This commit is contained in:
ishaan-jaff 2024-01-15 10:45:39 -08:00
parent 63e80056e5
commit b43cf53473

View file

@ -113,7 +113,9 @@ app = FastAPI(
) )
class OpenAIException(Exception): class ProxyException(Exception):
# NOTE: DO NOT MODIFY THIS
# This is used to map exactly to OPENAI Exceptions
def __init__( def __init__(
self, self,
message: str, message: str,
@ -127,8 +129,9 @@ class OpenAIException(Exception):
self.code = code self.code = code
@app.exception_handler(OpenAIException) @app.exception_handler(ProxyException)
async def openai_exception_handler(request: Request, exc: OpenAIException): async def openai_exception_handler(request: Request, exc: ProxyException):
# NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
return JSONResponse( return JSONResponse(
status_code=int(exc.code) status_code=int(exc.code)
if exc.code if exc.code
@ -1461,11 +1464,12 @@ async def completion(
traceback.print_exc() traceback.print_exc()
error_traceback = traceback.format_exc() error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}" error_msg = f"{str(e)}\n\n{error_traceback}"
try: raise ProxyException(
status = e.status_code # type: ignore message=getattr(e, "message", error_msg),
except: type=getattr(e, "type", "None"),
status = 500 param=getattr(e, "param", "None"),
raise HTTPException(status_code=status, detail=error_msg) code=getattr(e, "status_code", 500),
)
@router.post( @router.post(
@ -1650,7 +1654,7 @@ async def chat_completion(
error_traceback = traceback.format_exc() error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}" error_msg = f"{str(e)}\n\n{error_traceback}"
raise OpenAIException( raise ProxyException(
message=getattr(e, "message", error_msg), message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"), type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"), param=getattr(e, "param", "None"),
@ -1791,11 +1795,12 @@ async def embeddings(
else: else:
error_traceback = traceback.format_exc() error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}" error_msg = f"{str(e)}\n\n{error_traceback}"
try: raise ProxyException(
status = e.status_code # type: ignore message=getattr(e, "message", error_msg),
except: type=getattr(e, "type", "None"),
status = 500 param=getattr(e, "param", "None"),
raise HTTPException(status_code=status, detail=error_msg) code=getattr(e, "status_code", 500),
)
@router.post( @router.post(
@ -1905,11 +1910,12 @@ async def image_generation(
else: else:
error_traceback = traceback.format_exc() error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}" error_msg = f"{str(e)}\n\n{error_traceback}"
try: raise ProxyException(
status = e.status_code # type: ignore message=getattr(e, "message", error_msg),
except: type=getattr(e, "type", "None"),
status = 500 param=getattr(e, "param", "None"),
raise HTTPException(status_code=status, detail=error_msg) code=getattr(e, "status_code", 500),
)
#### KEY MANAGEMENT #### #### KEY MANAGEMENT ####