From 42e0d7cf68690bc133d1a74a4aa41fc715487a1c Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 9 Oct 2023 14:56:05 -0700
Subject: [PATCH] fix(proxy_server): returns better error messages for invalid api errors

---
 litellm/__pycache__/main.cpython-311.pyc |  Bin 51217 -> 51091 bytes
 litellm/proxy/proxy_server.py            |  133 +++++++++++------------
 litellm/tests/test_bad_params.py         |    6 +-
 3 files changed, 67 insertions(+), 72 deletions(-)

diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index be268cfe4b601c11e3baa5e2dabc84107fccfe06..1fe305e41795c0da3ec22ebb29beaa137f91e43f 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index ca1ee6f23..28bd22223 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -118,6 +118,66 @@ def data_generator(response):
         print_verbose(f"returned chunk: {chunk}")
         yield f"data: {json.dumps(chunk)}\n\n"
 
+def litellm_completion(data, type):
+    try:
+        if user_model:
+            data["model"] = user_model
+        # override with user settings
+        if user_temperature:
+            data["temperature"] = user_temperature
+        if user_max_tokens:
+            data["max_tokens"] = user_max_tokens
+        if user_api_base:
+            data["api_base"] = user_api_base
+        ## CUSTOM PROMPT TEMPLATE ## - run `litellm --config` to set this
+        litellm.register_prompt_template(
+            model=user_model,
+            roles={
+                "system": {
+                    "pre_message": os.getenv("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
+                    "post_message": os.getenv("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
+                },
+                "assistant": {
+                    "pre_message": os.getenv("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
+                    "post_message": os.getenv("MODEL_ASSISTANT_MESSAGE_END_TOKEN", "")
+                },
+                "user": {
+                    "pre_message": os.getenv("MODEL_USER_MESSAGE_START_TOKEN", ""),
+                    "post_message": os.getenv("MODEL_USER_MESSAGE_END_TOKEN", "")
+                }
+            },
+            initial_prompt_value=os.getenv("MODEL_PRE_PROMPT", ""),
+            final_prompt_value=os.getenv("MODEL_POST_PROMPT", "")
+        )
+        if type == "completion":
+            response = litellm.text_completion(**data)
+        elif type == "chat_completion":
+            response = litellm.completion(**data)
+        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
+            return StreamingResponse(data_generator(response), media_type='text/event-stream')
+        print_verbose(f"response: {response}")
+        return response
+    except Exception as e:
+        if "Invalid response object from API" in str(e):
+            completion_call_details = {}
+            if user_model:
+                completion_call_details["model"] = user_model
+            else:
+                completion_call_details["model"] = data['model']
+
+            if user_api_base:
+                completion_call_details["api_base"] = user_api_base
+            else:
+                completion_call_details["api_base"] = None
+            print(f"\033[1;31mLiteLLM.Exception: Invalid API Call. Call details: Model: \033[1;37m{completion_call_details['model']}\033[1;31m; LLM Provider: \033[1;37m{e.llm_provider}\033[1;31m; Custom API Base - \033[1;37m{completion_call_details['api_base']}\033[1;31m\033[0m")
+            if completion_call_details["api_base"] == "http://localhost:11434":
+                print()
+                print("Trying to call ollama? Try `litellm --model ollama/llama2 --api_base http://localhost:11434`")
+                print()
+        else:
+            print(f"\033[1;31mLiteLLM.Exception: {str(e)}\033[0m")
+        return {"message": "An error occurred"}, 500
+
 #### API ENDPOINTS ####
 @router.get("/models") # if project requires model list
 def model_list():
@@ -136,82 +196,15 @@ def model_list():
 @router.post("/completions")
 async def completion(request: Request):
     data = await request.json()
-    print_verbose(f"data passed in: {data}")
-    if user_model:
-        data["model"] = user_model
-    if user_api_base:
-        data["api_base"] = user_api_base
-    # override with user settings
-    if user_temperature:
-        data["temperature"] = user_temperature
-    if user_max_tokens:
-        data["max_tokens"] = user_max_tokens
-
-    ## check for custom prompt template ##
-    litellm.register_prompt_template(
-        model=user_model,
-        roles={
-            "system": {
-                "pre_message": os.getenv("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
-            },
-            "assistant": {
-                "pre_message": os.getenv("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_ASSISTANT_MESSAGE_END_TOKEN", "")
-            },
-            "user": {
-                "pre_message": os.getenv("MODEL_USER_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_USER_MESSAGE_END_TOKEN", "")
-            }
-        },
-        initial_prompt_value=os.getenv("MODEL_PRE_PROMPT", ""),
-        final_prompt_value=os.getenv("MODEL_POST_PROMPT", "")
-    )
-    response = litellm.text_completion(**data)
-    if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
-        return StreamingResponse(data_generator(response), media_type='text/event-stream')
-    return response
+    return litellm_completion(data=data, type="completion")
 
 @router.post("/chat/completions")
 async def chat_completion(request: Request):
     data = await request.json()
     print_verbose(f"data passed in: {data}")
-    if user_model:
-        data["model"] = user_model
-    # override with user settings
-    if user_temperature:
-        data["temperature"] = user_temperature
-    if user_max_tokens:
-        data["max_tokens"] = user_max_tokens
-    if user_api_base:
-        data["api_base"] = user_api_base
-    ## check for custom prompt template ##
-    litellm.register_prompt_template(
-        model=user_model,
-        roles={
-            "system": {
-                "pre_message": os.getenv("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
-            },
-            "assistant": {
-                "pre_message": os.getenv("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_ASSISTANT_MESSAGE_END_TOKEN", "")
-            },
-            "user": {
-                "pre_message": os.getenv("MODEL_USER_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_USER_MESSAGE_END_TOKEN", "")
-            }
-        },
-        initial_prompt_value=os.getenv("MODEL_PRE_PROMPT", ""),
-        final_prompt_value=os.getenv("MODEL_POST_PROMPT", "")
-    )
-    response = litellm.completion(**data)
-
+    response = litellm_completion(data, type="chat_completion")
     # track cost of this response, using litellm.completion_cost
-    await track_cost(response)
-    if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
-        return StreamingResponse(data_generator(response), media_type='text/event-stream')
-    print_verbose(f"response: {response}")
+    await track_cost(response)
     return response
 
 async def track_cost(response):
diff --git a/litellm/tests/test_bad_params.py b/litellm/tests/test_bad_params.py
index b563b02af..886c139f3 100644
--- a/litellm/tests/test_bad_params.py
+++ b/litellm/tests/test_bad_params.py
@@ -54,17 +54,19 @@ def test_completion_invalid_param_cohere():
         else:
             pytest.fail(f'An error occurred {e}')
 
-test_completion_invalid_param_cohere()
+# test_completion_invalid_param_cohere()
 
 def test_completion_function_call_cohere():
     try:
-        response = completion(model="command-nightly", messages=messages, function_call="TEST-FUNCTION")
+        response = completion(model="command-nightly", messages=messages, functions=["TEST-FUNCTION"])
     except Exception as e:
         if "Function calling is not supported by this provider" in str(e):
             pass
         else:
             pytest.fail(f'An error occurred {e}')
 
+test_completion_function_call_cohere()
+
 def test_completion_function_call_openai():
     try:
         messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
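
Testing note: below is a minimal sketch (not part of the patch) for exercising the new error path by hand, assuming the proxy is running locally. The base URL, port, and model string are placeholders, not values defined by this change -- use whatever the `litellm` CLI prints on startup.

    import requests  # assumes the proxy started via `litellm --model ... --api_base ...` is reachable

    # Placeholder address -- the proxy prints its actual host/port when it starts.
    PROXY_BASE = "http://0.0.0.0:8000"

    payload = {
        "model": "ollama/llama2",  # placeholder; the server overrides this if --model was passed
        "messages": [{"role": "user", "content": "hello"}],
    }

    # With this patch, an invalid upstream call is caught inside litellm_completion:
    # the server prints the formatted "LiteLLM.Exception: Invalid API Call ..." message
    # (plus the ollama hint when api_base is http://localhost:11434), and the client
    # receives the generic error payload instead of an unhandled exception.
    resp = requests.post(f"{PROXY_BASE}/chat/completions", json=payload)
    print(resp.status_code, resp.text)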