From c63db48652324e613b3271f95a98595fdb70bce0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 20 Sep 2023 14:43:25 -0700 Subject: [PATCH] return all best of sequences --- litellm/__pycache__/__init__.cpython-311.pyc | Bin 10846 -> 10846 bytes litellm/__pycache__/utils.cpython-311.pyc | Bin 115128 -> 115168 bytes litellm/llms/huggingface_restapi.py | 34 +++++++++++++------ litellm/tests/test_completion.py | 16 +++++++++ litellm/utils.py | 3 +- 5 files changed, 41 insertions(+), 12 deletions(-) diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc index 709d54116fadc7b54e0c9f47acef805519c4464e..f98964ab77456264475fe0ac626208de302782df 100644 GIT binary patch delta 19 ZcmcZ?axa8yIWI340}yy5ZsZEl0suWH1snhX delta 19 ZcmcZ?axa8yIWI340}wn;*vJ*41pq+A1-1YH diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc index bd50ebc532cd0fd690a9fd0fe6d670df29d43697..8b3ac9683fe03befb508f1b9efc1f512bdae0f39 100644 GIT binary patch delta 4983 zcmai13wTu35zg%H-8?oSfjoGIKms940-Nv_1PK)sk+dL|Aed#%t|1Sy%h`=e41%B% zc?CIyqn~|%BE%{vE^<|BZKX=Iib7Rz1pyx|R0}BhLWItLHchDQXO}NOb7$tv%sDf6 z&fMpg#U5K0+i7c3QbG*>;=K-^HDyPq`&ZH*!Vj)&rKI|X6mLqk1wv6HkFOjg1I-DY z6B=v{Y2MCPGS}OMsZ@D)a}K4+w^omwl3s1q+D1bhOERF-0G+Pw+TAoSm3iGkt+mz0 zv6migUP>DbVKf*l?J1|NS(w~mFP?e5WK(lGdpf;FQKnqpG@tRSrplyDZOKymC0pt( ztJfA&7DqFz`&vtFm|1<=j1TwrZAiW?M!4eT(`$=l{sTi(v)e>Vy*UsCZQg!z!2?Cq zU-n&>FZVr=;mB>X?eY$gr&*E*2{~nPQ+iTVMzv=el0z|aQ$?Cww$4QZ`8-nPm8&x3 zF*&$@#}05J`Fi*9-VSbbUoM$0%hucF`sNgQlgx`B)E3t=Ic-(G+$BrLbgcLZJKDO# zYi-YhgMDU6w%&EHmL+<-v{^QE?p!!gxng}T70Mm!Z|_>9wa7c9-2$64x z#opozhTQ&JXGG$?es|C_KTzY7CpHwy?&_$8YQx_sC(W%SuG1>OM8G6~7ciF#O-uq@ zgU52fF9FvAW&tVy_d%r^R2itx0aNif9WX;Kd1!1x9g|mckq7@6-uuuilTj=D9sJfq5QL3bqe0FI={z znDS~s&j$=+h;;S`e6_+~7xD@Be81OM6BM-&2IStYSCEj@l4cKr5t2Py#u>|*91|{U zdBd>Z0q#OsxV<19z9UJ+lZO-)xr=`>**#_AEU`!~*xtP>+((ip6yGp$#>A3#S-sq} zeK7gLZ*Pw!dn4@JFONKxV{Bq_Vc2+juGM&y$vxq!u(Z%2dGXnOv{}BnYZU!X+RXxE zCktN>k2GC2$4abT1(0&V?o`KGkm~_WvT651<3%RVg_C}N!9g#}gw~(iUw~}4bhq|S z*$c82@D^a7T;00gegNcQ>3yv@>t!Y*orAvGkXYyr2~WAtEj)ZsddWUgUp!3R<&oDC zC^3BEwVW>Wrc69sY<~#`_R1R$FQX6S!NcVgkz1&>J@Z|pN z#k|IuATEI$!wouGuq$9m8-4-QBEWV4L8S)NA$jO*F}*K)oXZ&;!Rqe;7XS+K2Lbv$ zXaVmbJX)~YDyz?3wb%x-6wH?aI#IFU^~Yl_U;v=~{Dd7+ePr7I3AP$>)c7mh!MgHt zK2dQJe)P%x4lA|qe}MNH;B!C~1T74);x`~q$-=LuWhKFKBA^o>84w3(fX*ek@vFh) z2*34JZmiLb*&D;>FZ@8XMSgqnJK7xn^3oZi7WFYv->!}9#2FXe*Ah6q8*qo}X`w9Y zrN&q&)9Aw{c9}O@XepK*vr>(b!>m`#X*OC$hLh<7=KBUUSseqJ7Wsez6^y6c|6pv6 zrvk&p!if29Jhj@=b>06S4)wVn26aF^(}}L_wHx#pz(#Co69Xq5(TWi&BblzSkA|0c zbz?G(GOlI%dvjef#o1Ddz#jq_%23}=?MtE_J#`;(*Hi{WVi;t@0V4n-)zMU%earuh z1@7j+|3ql&T#N&C6<|D|7@*VJ=RbqBMXgVQT!%VE{VI*Jjq8~)(A<+oOS*Tz6H79o zF%}fMSAFh`d6hwRq=>TAor9^G8r9ziQaiJ=-Cf~)6NA2T;S0HM^VHP& zMD#S&u^Cj50I>i(Zf4v_V=YNIY^NbcVqZ(u+?jMWZ|Sj_^lLWwvr?Kse^&RE(r`xx z{N+LGUiE4zjk52JXL`A^%_2XmhGx-Jdfj|^7X5%=Nr#%|cI+cGjt+v0*zJeij+(#Ma zJ2%rsN`X1C2G9gJpgwZb%r5OgvPCaqJ)ow{rP=g%wRJ9ClL9B=0YF=X7w1xzeH|F< zRgQi#W>t7=uEU{@S2m2ESrzYB`PWtF1;2!qS~A zdfeI;e$;eJIbCNz1=Ho&!vHhOOZ|yftBF49+jlc7aRz>5M!)#5CAO&LKAJ&atHVC3 zv$R+g-9+6~^{uosOP|7{aKe8)tN>t6JOPqWWX%)?1C*I6!`VbS1~Y&N?+aNNyTs?hewy_#!<5&O!apfRz7Y>Z3E_py-Ua5Hj7+ z>Osv`dlr#v@rzJ~SJ4dkHA5uM?M4J{x7Y;U^MI`o#PS3gs0#&!KG|*HKLgNrjc!>_ z;8Bm3emDsL765bydpe7zdpg}qm5c9#?fU^M0qt9gZUlP}Izk<~jz-77 z69TtpME-oJgW>k4wMWv8=7xQx zHrneTAS!5gWgkR>#2o@pV(#^OYf_R&D1CTIVkI{phw;haqu=(-%LGNzqy%f<4Dxp zOU+bae9ZJ8&HN42XwW3J^I;l6pQ_IurZL9nEcu6-|0~*Wu^+~Y7Sr-uiYMb7KXg?o z8|gDr>cU1g@-1_6%z2x*BlO4D9+%fCtT@dONiO%)x;FgUc&ca1IDU?&2;;KUQ9YkWaCWEGmU$cuiI!fX+MAo=IfhjC$W*m zTc|k!uC@P3X78=^iNQmobsJqnXVjE!G&&c15C~fb>U2aWRDI{40oTFblOj<)v5j2z zFTnUl9ot4%7@e74VfNlmE-O8*N_J5H=(MN@i9AH8U*pB9aR{UKlNZPB}CBQmAm>7y{DuVigae?ID@>2ZFX%G=2=dY!po zCtXEUs&+g}JI8*EHBZ9C#{m5upl5zP>7Rg9&+w-}=|NtP`FfiFr+Q=;U)Oq?9IKA) z;wFHPYJpE!u?c{0OfegV@j)r@sVF|e;|0KZ0492|A5>QQv$O&5m5JJQMR@SZn6}f!1381(p^njsO4v delta 4828 zcmb7H3wTXe7S6i4=O($42T2eiFNEurhp5CO8udzBgv1Cb5xG}{NaEahRftxFdWJ3C zOieYdqMcG)9A(t>mFaDeDZFQ^&+8OxlGTQB0x2%o=W>f^*gQuDx)C6VI_*R74L;LVUMKh+tK-rr!)k5(?m}&{C+Fatx z_QM`}aq^}8M&v#ZzYAsf;n#TCks-l3_ z%BHF0I@@d`W>lwBWJ7dXvuaa`i`h{Pu6o#`MSb?Yu|h2SKP>q_V9B2?MT%4r(X@%Zc|(?(DneIff9={h>Lw1X zoz)^lEDeXUL=2XG@h_qzVJ{(TacY9AKkYMboZV7JERLaiTz z+dse9Q8qa*zqo92zCIf40>C6!)#h}vOaVPr9N*H~QVg<0e7_|%qKrv*;KY2Vquf!F zH?7p}D0b@8A=kwtThp{fO!jhBZ9S%0W`a9Q zSJxdAytP^;-*K&R2@{o6!d%`E0u7AlAGu=AK{n?5nFHoK096DHy!LpE19Sjf#9|Y`45`{N<^F5tx&4bv z3-g>a3JM%fr~VDJ)W*GrmCEg=dUANZ7ep!AAg1pc`-8~W>>m$) z0w7Ty51~%MNuav}k^w!WPblT4{HAZ1e**p{8J9z8L3r>S*ouHgUr;mU+-@YON(RMKl>g)4KMF@xGBuv&(rmdWo<5*l z`BVa>`KvV5HENYd6KHhseSLK{s*J=$YC<$urX|sIikCZ*s532)XOk#`UEEHh*kS58 z^@fXjfO4kVR@rI_UUNVQAQaFNFcrb7s~G9Jwc^rzyHhWdle<%!9;!fU<5YQWW5FE& z_Fw;zDQ-O#VKx|7Aj*>ayVD@+6JP*F7}jJ83!@Ynox#!O8+|h9syFSEorls+Z8~cd z85f3<-OL(;Mp9ZRDrYn4Jz)SbaT`pR>qb(d#Zwyj(MW2Upvq9nVz*CO{v1TDqdZPC z#}Lhbs;%`r7y&FZT4d24qELA}n_|OjIBRzkr=x(+Mjp5R42SMXQ~HjgE>?(jz*Zw| z6!kR)U~_Gn-V+HfkV{6>0D4m%A5DK?gHv;;KfNor=THy-Q26Tzt%dSx4y9NQ`!T&p zb{RvZth!+g4W%Q-r7@IA7S<-@r&kB%E}<NKf$m(k{Hk|DSf0fv3g}V# z&WNzn2rQoFpqQAKSj7AGAB)vFg9Uw)e8WN6mRr0_&SDvDrw%e<5@nb+o8*M?)KZE` zv^z@WSqmrp^YYwcFs#1}*d!+vQV#vgI8;b+6sDA;QNIUO?oOs*7K@d=WW!6moL5Ab z!c}x8&#_ktao52^wOlZTrct20I)$#%8+eM6_omV?3Ns!nrX=qtlou7v1N#rB$}?Vy z?j~x?$t#^vUN)oLb8D#kS2=Cs*%B1@!D1kvQSwYF*=UpWF5@nVFtW?&D-%`AR3}At zS-=`@bBSYix#ul?zaZFRo7oe;dZ|{+g-*(&p2kHdo%Tz_c%n@H~UtH!lwnSC&k+{}q`iMAtO zRr%4-9?HN+Ysr3sl2z6I%%hwe#Pj#ipFYmN>;63n z_D|--8r8KFN$v5`#qD(kE52iJ2NmR(sKaYUAh09Xsiu7Qx!JhQ5m&c#(eQqT~U-%o;Sz zwpKr{`paI*A9yo8ua*~4H+7J|K&tBfs;sF`JXJk4boY^&yZLlGjrF^!AJOmRnLV_- z?**)R9VRXU)U>W9b2X2D4XK*6uYgkHwi?FO%zarN-OKM@HPiK$z6KRl;@eioXQqxX zKfM;98VetIIzHAspIG@EOvw|){-Bq$8PIwoIz51=bR9S6*SYS7WYhcX5l&)A`c z2z3`6HKj8X2UN#d(9x|r`cc1MH`P3 1: + if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]: + choices_list = [] + for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]): + sum_logprob = 0 + for token in item["tokens"]: + sum_logprob += token["logprob"] + message_obj = Message(content=item["generated_text"], logprobs=sum_logprob) + choice_obj = Choices(finish_reason=item["finish_reason"], index=idx, message=message_obj) + choices_list.append(choice_obj) + model_response["choices"] = choices_list + else: + model_response["choices"][0]["message"][ + "content" + ] = completion_response[0]["generated_text"] + ## GETTING LOGPROBS + FINISH REASON + if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]: + model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"] + sum_logprob = 0 + for token in completion_response[0]["details"]["tokens"]: + sum_logprob += token["logprob"] + model_response["choices"][0]["message"]["logprobs"] = sum_logprob else: model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"] ## CALCULATING USAGE diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 10b527dc0..0e995cd9c 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -135,6 +135,22 @@ def test_completion_with_litellm_call_id(): # test_completion_hf_api() +# def test_completion_hf_api_best_of(): +# # failing on circle ci commenting out +# try: +# user_message = "write some code to find the sum of two numbers" +# messages = [{ "content": user_message,"role": "user"}] +# api_base = "https://a8l9e3ucxinyl3oj.us-east-1.aws.endpoints.huggingface.cloud" +# response = completion(model="huggingface/meta-llama/Llama-2-7b-chat-hf", messages=messages, api_base=api_base, n=2) +# # Add any assertions here to check the response +# print(response) +# except Exception as e: +# if "loading" in str(e): +# pass +# pytest.fail(f"Error occurred: {e}") + +# test_completion_hf_api_best_of() + # def test_completion_hf_deployed_api(): # try: # user_message = "There's a llama in my garden 😱 What should I do?" diff --git a/litellm/utils.py b/litellm/utils.py index ec525f536..17bf38f2d 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -902,7 +902,8 @@ def get_optional_params( # use the openai defaults if top_p != 1: optional_params["top_p"] = top_p if n != 1: - optional_params["n"] = n + optional_params["best_of"] = n + optional_params["do_sample"] = True # need to sample if you want best of for hf inference endpoints if stream: optional_params["stream"] = stream if stop != None: