From 9155ba068fa3b80c4e67d78e7b21cc1e07dd5a11 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 26 Aug 2023 15:47:07 -0700
Subject: [PATCH] fix anthropic and together ai streaming

---
 litellm/__pycache__/main.cpython-311.pyc  | Bin 28803 -> 28957 bytes
 litellm/__pycache__/utils.cpython-311.pyc | Bin 67002 -> 67157 bytes
 litellm/main.py                           |  19 +++--
 litellm/tests/test_streaming.py           |  93 +++++++++++++++++++---
 litellm/utils.py                          |  21 ++---
 5 files changed, 105 insertions(+), 28 deletions(-)

diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 61e0b648e5f383f2c4916688546a686f9c07a67d..be35f5b19360c1d8ada21831d22b3684f43ddb9d 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
z)MnWKB>jLFD6^%AZKwK{0qLnQyaW_r^*bRo0nNaRfZUW%MXRAKt*+KM&!_V(DY4I^ zZ4V%aeFf616mvXRIVi+k`mE!_Ls=s|(>le5DVA@9!#e$1>!-{m=lmwRjL7r&wQ06DZyx8BK2%SZT07*7Gu0;dH;X_sfc1?{rn9iurX)5HG;9e7LM zc5<@Hw$eB6^<%U2u=jN{o2UQv!{cUQt7*%QKw_L`{$~U<)7x#yYA_4*O3IHQzqB!w~C&Ie}PUxt{)UiCLACEuanUsp0Kt0ygQ#+n;3O(jI+o z9_vF*=TcZFy>l*`?ICMNPFU}am2DBpXCjiRxTA=zq%9qJqEcVzNM;euBHB#dh(TGMe;I$1C!GGnqj%VI6R zf;bFvFB_wv-lECZ)6?Yllpu5?v|LmUx5rl#+#AJTT~Qw3r{s9RSM=KTK`|J~2=R-` zhfiV)8n)84>nUMZglMouyqoM_k5*&E+-@uStglyER5^ny)OjOWu?usuu69LQY;ybL zxxPpVH;1gcH@jP+CvjhhC#Z>s0e9!f!=cI}V1sn`2=E?4b=r)uMgA(EyUZirjk4>W z@RE~^gcJqz5-@DiQ9NzLbKVFO0d)H2eXN@<-^^6vg+rP?=(~1BO~v{Oh;^r`&B)mv za&!o0AN$in=Que1=4~VIC3zd+A_~3$N~ry(k?Kp(cTn`5 zoK<#+4gtebE}jW#P`9D?cY(hHt6`I+9W|VzdIzf(%A>4Id=LCy08YRrONBh4ICdQS p%dwjr+2Y8>FUOhM?(}Al=$G!)TYAqnv*<3MTN((z@{x30@-G~4H diff --git a/litellm/main.py b/litellm/main.py index 19c393e49..3ff22ecc3 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -540,16 +540,9 @@ def completion( ## LOGGING logging.pre_call(input=prompt, api_key=TOGETHER_AI_TOKEN) - if stream == True: - return together_ai_completion_streaming( - { - "model": model, - "prompt": prompt, - "request_type": "language-model-inference", - **optional_params, - }, - headers=headers, - ) + + print(f"TOGETHER_AI_TOKEN: {TOGETHER_AI_TOKEN}") + res = requests.post( endpoint, json={ @@ -560,6 +553,12 @@ def completion( }, headers=headers, ) + + if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True: + response = CustomStreamWrapper( + res.iter_lines(), model, custom_llm_provider="together_ai" + ) + return response ## LOGGING logging.post_call( input=prompt, api_key=TOGETHER_AI_TOKEN, original_response=res.text diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 25fa5c047..7b55c9869 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -9,13 +9,14 @@ sys.path.insert( ) # Adds the parent directory to the system path import litellm from litellm import completion -litellm.logging = True -litellm.set_verbose = True +litellm.logging = False +litellm.set_verbose = False score = 0 def logger_fn(model_call_object: dict): + return print(f"model call details: {model_call_object}") @@ -81,17 +82,91 @@ except: # # test on huggingface completion call # try: +# start_time = time.time() # response = completion( -# model="meta-llama/Llama-2-7b-chat-hf", -# messages=messages, -# custom_llm_provider="huggingface", -# custom_api_base="https://s7c7gytn18vnu4tw.us-east-1.aws.endpoints.huggingface.cloud", -# stream=True, -# logger_fn=logger_fn, +# model="gpt-3.5-turbo", messages=messages, stream=True, logger_fn=logger_fn # ) +# complete_response = "" # for chunk in response: +# chunk_time = time.time() +# print(f"time since initial request: {chunk_time - start_time:.2f}") # print(chunk["choices"][0]["delta"]) -# score += 1 +# complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else "" +# if complete_response == "": +# raise Exception("Empty response received") # except: # print(f"error occurred: {traceback.format_exc()}") # pass + +# test on together ai completion call +try: + start_time = time.time() + response = completion( + model="Replit-Code-3B", messages=messages, logger_fn=logger_fn, stream= True + ) + complete_response = "" + print(f"returned response object: {response}") + for chunk in response: + chunk_time = time.time() + print(f"time since initial request: {chunk_time - start_time:.2f}") + 
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 25fa5c047..7b55c9869 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -9,13 +9,14 @@ sys.path.insert(
 )  # Adds the parent directory to the system path
 import litellm
 from litellm import completion
-litellm.logging = True
-litellm.set_verbose = True
+litellm.logging = False
+litellm.set_verbose = False
 
 score = 0
 
 
 def logger_fn(model_call_object: dict):
+    return
     print(f"model call details: {model_call_object}")
 
 
@@ -81,17 +82,91 @@ except:
 
 # # test on huggingface completion call
 # try:
+#     start_time = time.time()
 #     response = completion(
-#         model="meta-llama/Llama-2-7b-chat-hf",
-#         messages=messages,
-#         custom_llm_provider="huggingface",
-#         custom_api_base="https://s7c7gytn18vnu4tw.us-east-1.aws.endpoints.huggingface.cloud",
-#         stream=True,
-#         logger_fn=logger_fn,
+#         model="gpt-3.5-turbo", messages=messages, stream=True, logger_fn=logger_fn
 #     )
+#     complete_response = ""
 #     for chunk in response:
+#         chunk_time = time.time()
+#         print(f"time since initial request: {chunk_time - start_time:.2f}")
 #         print(chunk["choices"][0]["delta"])
-#         score += 1
+#         complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else ""
+#     if complete_response == "":
+#         raise Exception("Empty response received")
 # except:
 #     print(f"error occurred: {traceback.format_exc()}")
 #     pass
+
+# test on together ai completion call
+try:
+    start_time = time.time()
+    response = completion(
+        model="Replit-Code-3B", messages=messages, logger_fn=logger_fn, stream=True
+    )
+    complete_response = ""
+    print(f"returned response object: {response}")
+    for chunk in response:
+        chunk_time = time.time()
+        print(f"time since initial request: {chunk_time - start_time:.2f}")
+        print(chunk["choices"][0]["delta"])
+        complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else ""
+    if complete_response == "":
+        raise Exception("Empty response received")
+except:
+    print(f"error occurred: {traceback.format_exc()}")
+    pass
+
+
+# # test on azure completion call
+# try:
+#     response = completion(
+#         model="azure/chatgpt-test", messages=messages, stream=True, logger_fn=logger_fn
+#     )
+#     response = ""
+#     for chunk in response:
+#         chunk_time = time.time()
+#         print(f"time since initial request: {chunk_time - start_time:.2f}")
+#         print(chunk["choices"][0]["delta"])
+#         response += chunk["choices"][0]["delta"]
+#     if response == "":
+#         raise Exception("Empty response received")
+# except:
+#     print(f"error occurred: {traceback.format_exc()}")
+#     pass
+
+
+# # test on anthropic completion call
+# try:
+#     response = completion(
+#         model="claude-instant-1", messages=messages, stream=True, logger_fn=logger_fn
+#     )
+#     response = ""
+#     for chunk in response:
+#         chunk_time = time.time()
+#         print(f"time since initial request: {chunk_time - start_time:.2f}")
+#         print(chunk["choices"][0]["delta"])
+#         response += chunk["choices"][0]["delta"]
+#     if response == "":
+#         raise Exception("Empty response received")
+# except:
+#     print(f"error occurred: {traceback.format_exc()}")
+#     pass
+
+
+# # # test on huggingface completion call
+# # try:
+# #     response = completion(
+# #         model="meta-llama/Llama-2-7b-chat-hf",
+# #         messages=messages,
+# #         custom_llm_provider="huggingface",
+# #         custom_api_base="https://s7c7gytn18vnu4tw.us-east-1.aws.endpoints.huggingface.cloud",
+# #         stream=True,
+# #         logger_fn=logger_fn,
+# #     )
+# #     for chunk in response:
+# #         print(chunk["choices"][0]["delta"])
+# #         score += 1
+# # except:
+# #     print(f"error occurred: {traceback.format_exc()}")
+# #     pass
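The new together ai test above exercises the whole path end to end: it streams with stream=True, accumulates each chunk's delta["content"], and raises on an empty result. A slightly more defensive sketch of the same loop is below, assuming the chunk shape the test itself relies on (chunk["choices"][0]["delta"]); dict.get avoids a KeyError on deltas that carry a role but no content. The model name is copied from the test, the messages payload is illustrative, and Together AI credentials/routing are assumed to be configured:

    import time

    from litellm import completion

    messages = [{"role": "user", "content": "Hey, how's it going?"}]  # illustrative payload

    start_time = time.time()
    response = completion(model="Replit-Code-3B", messages=messages, stream=True)
    complete_response = ""
    for chunk in response:
        delta = chunk["choices"][0]["delta"]
        complete_response += delta.get("content") or ""  # tolerate content-less deltas
        print(f"time since initial request: {time.time() - start_time:.2f}")
    if complete_response == "":
        raise Exception("Empty response received")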
diff --git a/litellm/utils.py b/litellm/utils.py
index 96d094088..348304dca 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -371,6 +371,8 @@ def client(original_function):
             )
             if "logger_fn" in kwargs:
                 user_logger_fn = kwargs["logger_fn"]
+            # LOG SUCCESS
+            crash_reporting(*args, **kwargs)
         except:  # DO NOT BLOCK running the function because of this
             print_verbose(f"[Non-Blocking] {traceback.format_exc()}")
             pass
@@ -444,26 +446,27 @@ def client(original_function):
             function_setup(*args, **kwargs)
             litellm_call_id = str(uuid.uuid4())
             kwargs["litellm_call_id"] = litellm_call_id
-            # [OPTIONAL] CHECK CACHE
             start_time = datetime.datetime.now()
+            # [OPTIONAL] CHECK CACHE
             if (litellm.caching or litellm.caching_with_models) and (
                     cached_result := check_cache(*args, **kwargs)) is not None:
                 result = cached_result
-            else:
-                # MODEL CALL
-                result = original_function(*args, **kwargs)
+                return result
+            # MODEL CALL
+            result = original_function(*args, **kwargs)
+            if "stream" in kwargs and kwargs["stream"] == True:
+                return result
             end_time = datetime.datetime.now()
-            # Add response to CACHE
-            if litellm.caching:
+            # [OPTIONAL] ADD TO CACHE
+            if (litellm.caching or litellm.caching_with_models):
                 add_cache(result, *args, **kwargs)
             # LOG SUCCESS
-            crash_reporting(*args, **kwargs)
-
             my_thread = threading.Thread(
                 target=handle_success, args=(args, kwargs, result, start_time, end_time))  # don't interrupt execution of main thread
             my_thread.start()
+            # RETURN RESULT
             return result
 
         except Exception as e:
@@ -1465,7 +1468,7 @@ class CustomStreamWrapper:
         if model in litellm.cohere_models:
             # cohere does not return an iterator, so we need to wrap it in one
             self.completion_stream = iter(completion_stream)
-        elif model == "together_ai":
+        elif custom_llm_provider == "together_ai":
             self.completion_stream = iter(completion_stream)
         else:
             self.completion_stream = completion_stream
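The last hunk is the heart of the together ai fix. CustomStreamWrapper previously compared model == "together_ai", but model holds the user-supplied model name (for example "Replit-Code-3B"), so that branch could never match; dispatching on custom_llm_provider, which completion() now passes explicitly, keys the wrapping off the provider instead. Combined with the early return result for stream=True in the client wrapper, which skips caching and the success-logging thread for generators, this is what lets a streamed response reach the caller intact. A minimal sketch of the dispatch pattern, with names mirroring the hunk (the chunk shape emitted by __next__ is illustrative, not the library's actual parsing):

    class StreamWrapperSketch:
        def __init__(self, completion_stream, model, custom_llm_provider=None):
            self.model = model
            # Key off the provider, not the model name: providers that hand
            # back plain iterables (e.g. res.iter_lines()) get wrapped in
            # iter() so that next() works uniformly in __next__.
            if custom_llm_provider == "together_ai":
                self.completion_stream = iter(completion_stream)
            else:
                self.completion_stream = completion_stream

        def __iter__(self):
            return self

        def __next__(self):
            chunk = next(self.completion_stream)  # StopIteration ends the stream
            text = chunk.decode("utf-8") if isinstance(chunk, bytes) else str(chunk)
            return {"choices": [{"delta": {"content": text}}]}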