From 2384806cfd3a3351e93ee0f0a3390f502facc455 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 4 Sep 2023 14:48:16 -0700
Subject: [PATCH] adding first-party + custom prompt templates for huggingface

---
 litellm/__init__.py                          |   5 +-
 litellm/__pycache__/__init__.cpython-311.pyc | Bin 7757 -> 7843 bytes
 litellm/__pycache__/main.cpython-311.pyc     | Bin 28706 -> 29049 bytes
 litellm/__pycache__/utils.cpython-311.pyc    | Bin 83621 -> 84544 bytes
 .../factory.py                               | 109 ++++++++++++++++++
 litellm/llms/huggingface_restapi.py          |  21 ++--
 litellm/main.py                              |   4 +-
 litellm/tests/test_hf_prompt_templates.py    |  43 +++++++
 litellm/utils.py                             |  22 +++-
 pyproject.toml                               |   2 +-
 10 files changed, 186 insertions(+), 20 deletions(-)
 create mode 100644 litellm/llms/huggingface_model_prompt_templates/factory.py
 create mode 100644 litellm/tests/test_hf_prompt_templates.py

diff --git a/litellm/__init__.py b/litellm/__init__.py
index a4ed950a8..f3b51947c 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -47,7 +47,7 @@ def get_model_cost_map():
         print("Error occurred:", e)
         return None
 model_cost = get_model_cost_map()
-
+custom_prompt_dict = {}
 ####### THREAD-SPECIFIC DATA ###################
 class MyLocal(threading.local):
     def __init__(self):
@@ -224,7 +224,8 @@ from .utils import (
     acreate,
     get_model_list,
     completion_with_split_tests,
-    get_max_tokens
+    get_max_tokens,
+    register_prompt_template
 )
 from .main import *  # type: ignore
 from .integrations import *
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index e68c4da03434062d30b612cd17ca95e969019102..684c6587ed7e0d07af8ea2eb2abd7875e6fe5e9b 100644
GIT binary patch
(binary delta omitted)

diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 4e3285a90e09b54c4a536b6aa0d29e799a8e7456..809ab0689509481db1fd6dd0a21dd705f896112a 100644
GIT binary patch
(binary delta omitted)
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index 309ff5b2e133cb1ff0d779f9fe93e5d4c586e3b3..7a8c10db65d656e51173cd88c10ea4fa64786bbf 100644
GIT binary patch
(binary delta omitted)
zf$&9?4+FKk)pP0$tTrL}Tr9^{C7*k$>#?uoed`q&W982AiITBY>8O>H-!qSorOggH zFNz9itg{~erdW8w*8XC50i|kz2T>6fjGIyODWRa(bP* z{#9qFt5V&eOOTg3>Qy!ekYO2Q>N#}Mte->`L?4PPlc{l>AKIs$vsI9(ClX#q?8)Ot zJ&o!iTOrmI(m3rL_|L>Ug;bV^A~3wLSstJY*-jIKi>Oa4UOVgvQfLRmQzm@L[INST]" + elif message["role"] == "user": + prompt += message["content"] + "[/INST]" + return prompt + +def llama_2_pt(messages): + return " ".join(message["content"] for message in messages) + +# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110 +def falcon_instruct_pt(messages): + prompt = "" + for message in messages: + if message["role"] == "system": + prompt += messages["content"] + else: + prompt += message['role']+":"+ message["content"].replace("\r\n", "\n").replace("\n\n", "\n") + prompt += "\n\n" + + +# MPT prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110 +def mpt_chat_pt(messages): + prompt = "" + for message in messages: + if message["role"] == "system": + prompt += "<|im_start|>system" + message["content"] + "<|im_end|>" + "\n" + elif message["role"] == "assistant": + prompt += "<|im_start|>assistant" + message["content"] + "<|im_end|>" + "\n" + elif message["role"] == "user": + prompt += "<|im_start|>user" + message["content"] + "<|im_end|>" + "\n" + return prompt + +# WizardCoder prompt template - https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0#prompt-format +def wizardcoder_pt(messages): + prompt = "" + for message in messages: + if message["role"] == "system": + prompt += message["content"] + "\n\n" + elif message["role"] == "user": # map to 'Instruction' + prompt += "### Instruction:\n" + message["content"] + "\n\n" + elif message["role"] == "assistant": # map to 'Response' + prompt += "### Response:\n" + message["content"] + "\n\n" + return prompt + +# Phind-CodeLlama prompt template - https://huggingface.co/Phind/Phind-CodeLlama-34B-v2#how-to-prompt-the-model +def phind_codellama_pt(messages): + prompt = "" + for message in messages: + if message["role"] == "system": + prompt += "### System Prompt\n" + message["content"] + "\n\n" + elif message["role"] == "user": + prompt += "### User Message\n" + message["content"] + "\n\n" + elif message["role"] == "assistant": + prompt += "### Assistant\n" + message["content"] + "\n\n" + return prompt + +# Custom prompt template +def custom_prompt(role_dict: dict, pre_message_sep: str, post_message_sep: str, messages: list): + prompt = "" + for message in messages: + if message["role"] == "system": + prompt += f"{role_dict['system']}{pre_message_sep}" + message["content"] + post_message_sep + elif message["role"] == "user": + prompt += f"{role_dict['user']}{pre_message_sep}" + message["content"] + post_message_sep + elif message["role"] == "assistant": + prompt += f"{role_dict['assistant']}{pre_message_sep}" + message["content"] + post_message_sep + return prompt + +def prompt_factory(model: str, messages: list): + model = model.lower() + if "bloom" in model: + return default_pt(messages=messages) + elif "flan-t5" in model: + return default_pt(messages=messages) + elif "meta-llama" in model: + if "chat" in model: + return llama_2_chat_pt(messages=messages) + else: + return default_pt(messages=messages) + elif "falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template. 
+        if "instruct" in model:
+            return falcon_instruct_pt(messages=messages)
+        else:
+            return default_pt(messages=messages)
+    elif "mpt" in model:
+        if "chat" in model:
+            return mpt_chat_pt(messages=messages)
+        else:
+            return default_pt(messages=messages)
+    elif "codellama/codellama" in model:
+        if "instruct" in model:
+            return llama_2_chat_pt(messages=messages)  # https://huggingface.co/blog/codellama#conversational-instructions
+        else:
+            return default_pt(messages=messages)
+    elif "wizardcoder" in model:
+        return wizardcoder_pt(messages=messages)
+    elif "phind-codellama" in model:
+        return phind_codellama_pt(messages=messages)
+    else:
+        return default_pt(messages=messages)
\ No newline at end of file
diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index dcd7c3efd..51b61b0d8 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -7,6 +7,7 @@ import time
 from typing import Callable
 from litellm.utils import ModelResponse
 from typing import Optional
+from .huggingface_model_prompt_templates.factory import prompt_factory, custom_prompt
 
 class HuggingfaceError(Exception):
     def __init__(self, status_code, message):
@@ -33,6 +34,7 @@ def completion(
     encoding,
     api_key,
     logging_obj,
+    custom_prompt_dict={},
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
@@ -47,21 +49,12 @@ def completion(
         completion_url = os.getenv("HF_API_BASE", "")
     else:
         completion_url = f"https://api-inference.huggingface.co/models/{model}"
-    prompt = ""
-    if (
-        "meta-llama" in model and "chat" in model
-    ):  # use the required special tokens for meta-llama - https://huggingface.co/blog/llama2#how-to-prompt-llama-2
-        prompt = ""
-        for message in messages:
-            if message["role"] == "system":
-                prompt += "[INST] <<SYS>>" + message["content"]
-            elif message["role"] == "assistant":
-                prompt += message["content"] + "[INST]"
-            elif message["role"] == "user":
-                prompt += message["content"] + "[/INST]"
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(role_dict=model_prompt_details["roles"], pre_message_sep=model_prompt_details["pre_message_sep"], post_message_sep=model_prompt_details["post_message_sep"], messages=messages)
     else:
-        for message in messages:
-            prompt += f"{message['content']}"
+        prompt = prompt_factory(model=model, messages=messages)
     ### MAP INPUT PARAMS
     data = {
         "inputs": prompt,
diff --git a/litellm/main.py b/litellm/main.py
index a5a9b0b38..05b0f1981 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -563,8 +563,8 @@ def completion(
             logger_fn=logger_fn,
             encoding=encoding,
             api_key=huggingface_key,
-            logging_obj=logging
-
+            logging_obj=logging,
+            custom_prompt_dict=litellm.custom_prompt_dict
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
diff --git a/litellm/tests/test_hf_prompt_templates.py b/litellm/tests/test_hf_prompt_templates.py
new file mode 100644
index 000000000..0066a9e9d
--- /dev/null
+++ b/litellm/tests/test_hf_prompt_templates.py
@@ -0,0 +1,43 @@
+# import sys, os
+# import traceback
+# from dotenv import load_dotenv
+
+# load_dotenv()
+# import os
+
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# )  # Adds the parent directory to the system path
+# import pytest
+# import litellm
+# from litellm import embedding, completion, text_completion
+
+# def logger_fn(user_model_dict):
+#     return
+#     print(f"user_model_dict: {user_model_dict}")
+
+# messages=[{"role": "user", "content": "Write me a function to print hello world"}]
+
+# # test if the first-party prompt templates work
+# def test_huggingface_supported_models():
+#     model = "huggingface/WizardLM/WizardCoder-Python-34B-V1.0"
+#     response = completion(model=model, messages=messages, max_tokens=256, api_base="https://ji16r2iys9a8rjk2.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn)
+#     print(response['choices'][0]['message']['content'])
+#     return response
+
+# test_huggingface_supported_models()
+
+# # test if a custom prompt template works
+# litellm.register_prompt_template(
+#     model="togethercomputer/LLaMA-2-7B-32K",
+#     roles={"system": "", "assistant": "Assistant:", "user": "User:"},
+#     pre_message_sep="\n",
+#     post_message_sep="\n"
+# )
+# def test_huggingface_custom_model():
+#     model = "huggingface/togethercomputer/LLaMA-2-7B-32K"
+#     response = completion(model=model, messages=messages, api_base="https://ecd4sb5n09bo4ei2.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn)
+#     print(response['choices'][0]['message']['content'])
+#     return response
+
+# test_huggingface_custom_model()
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index d611c05ed..3f403ce70 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1326,6 +1326,27 @@ def modify_integration(integration_name, integration_params):
         Supabase.supabase_table_name = integration_params["table_name"]
 
 
+# custom prompt helper function
+def register_prompt_template(model: str, roles: dict, pre_message_sep: str, post_message_sep: str):
+    """
+    Example usage:
+    ```
+    import litellm
+    litellm.register_prompt_template(
+        model="bloomz",
+        roles={"system": "<|im_start|>system", "assistant": "<|im_start|>assistant", "user": "<|im_start|>user"},
+        pre_message_sep="\n",
+        post_message_sep="<|im_end|>\n"
+    )
+    ```
+    """
+    litellm.custom_prompt_dict[model] = {
+        "roles": roles,
+        "pre_message_sep": pre_message_sep,
+        "post_message_sep": post_message_sep
+    }
+    return litellm.custom_prompt_dict
+
 
 ####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging
 
@@ -1415,7 +1436,6 @@ def get_model_list():
             f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
         )
 
-
 ####### EXCEPTION MAPPING ################
 def exception_type(model, original_exception, custom_llm_provider):
     global user_logger_fn, liteDebuggerClient
diff --git a/pyproject.toml b/pyproject.toml
index a733fdd0b..d56f01f9c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.531"
+version = "0.1.532"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
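
A minimal usage sketch of the API this patch adds, mirroring the commented-out test above: register a custom prompt template with `litellm.register_prompt_template`, then route a `completion` call through the Hugging Face provider. The `api_base` URL below is a placeholder for a deployed inference endpoint, and the model name is only illustrative.

```python
# Sketch: register a custom chat template, then complete through the Hugging Face provider.
# Assumes a deployed Hugging Face inference endpoint; the api_base below is a placeholder.
import litellm
from litellm import completion

litellm.register_prompt_template(
    model="togethercomputer/LLaMA-2-7B-32K",
    roles={"system": "", "assistant": "Assistant:", "user": "User:"},
    pre_message_sep="\n",
    post_message_sep="\n",
)

response = completion(
    model="huggingface/togethercomputer/LLaMA-2-7B-32K",
    messages=[{"role": "user", "content": "Write me a function to print hello world"}],
    api_base="https://<your-endpoint>.us-east-1.aws.endpoints.huggingface.cloud",  # placeholder endpoint
)
print(response["choices"][0]["message"]["content"])
```

Models without a registered template fall back to `prompt_factory`, which selects a first-party template (Llama-2 chat, Falcon instruct, MPT chat, WizardCoder, Phind-CodeLlama) based on the model name.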