From 6aff47083be659b80e00cb81eb783cb24db2e183 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 14 Aug 2023 13:06:33 -0700
Subject: [PATCH] adding support for meta-llama-2

---
 .../docs/completion/huggingface_tutorial.md  | 45 +++++++++
 docs/my-website/sidebars.js                  |  2 +-
 litellm/__init__.py                          | 21 +++-
 litellm/__pycache__/__init__.cpython-311.pyc | Bin 3888 -> 3914 bytes
 litellm/__pycache__/main.cpython-311.pyc     | Bin 21832 -> 21013 bytes
 litellm/__pycache__/utils.cpython-311.pyc    | Bin 41362 -> 42672 bytes
 litellm/llms/anthropic.py                    |  2 +-
 litellm/llms/huggingface_restapi.py          | 94 ++++++++++++++++++
 litellm/main.py                              | 48 +++------
 litellm/tests/test_completion.py             | 11 ++
 litellm/tests/test_streaming.py              | 11 ++
 litellm/utils.py                             | 29 +++++-
 12 files changed, 220 insertions(+), 43 deletions(-)
 create mode 100644 docs/my-website/docs/completion/huggingface_tutorial.md
 create mode 100644 litellm/llms/huggingface_restapi.py

diff --git a/docs/my-website/docs/completion/huggingface_tutorial.md b/docs/my-website/docs/completion/huggingface_tutorial.md
new file mode 100644
index 0000000000..ea822dec56
--- /dev/null
+++ b/docs/my-website/docs/completion/huggingface_tutorial.md
@@ -0,0 +1,45 @@
+# Llama2 - Huggingface Tutorial
+[Huggingface](https://huggingface.co/) is an open-source platform for deploying machine-learning models.
+
+## Call Llama2 with Huggingface Inference Endpoints
+LiteLLM makes it easy to call your public, private, or default Huggingface endpoints.
+
+In this case, let's call 3 models:
+- `deepset/deberta-v3-large-squad2`: calls the default Huggingface endpoint
+- `meta-llama/Llama-2-7b-hf`: calls a public endpoint
+- `meta-llama/Llama-2-7b-chat-hf`: calls your private endpoint
+
+### Case 1: Call the default Huggingface endpoint
+
+Here's the complete example:
+
+```
+from litellm import completion
+
+model = "deepset/deberta-v3-large-squad2"
+messages = [{"role": "user", "content": "Hey, how's it going?"}] # LiteLLM follows the OpenAI format
+
+### CALLING ENDPOINT
+completion(model=model, messages=messages, custom_llm_provider="huggingface")
+```
+
+What's happening?
+- model - the name of the model deployed on Huggingface
+- messages - the input. We accept the OpenAI chat format. For Huggingface, by default we iterate through the list and append each message["content"] to the prompt.
+
+### Case 2: Call the Llama2 public endpoint
+
+We've deployed `meta-llama/Llama-2-7b-hf` behind a public endpoint - `https://ag3dkq4zui5nu8g3.us-east-1.aws.endpoints.huggingface.cloud`.
+ +Let's try it out: +``` +from litellm import completion + +model = "meta-llama/Llama-2-7b-hf" +messages = [{"role": "user", "content": "Hey, how's it going?"}] # LiteLLM follows the OpenAI format +custom_api_base = "https://ag3dkq4zui5nu8g3.us-east-1.aws.endpoints.huggingface.cloud" + +### CALLING ENDPOINT +completion(model=model, messages=messages, custom_llm_provider="huggingface", custom_api_base=custom_api_base) +``` + diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index d744f1490a..b85a8f7eff 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -22,7 +22,7 @@ const sidebars = { { type: 'category', label: 'completion_function', - items: ['completion/input', 'completion/supported','completion/output'], + items: ['completion/input', 'completion/supported','completion/output', 'completion/huggingface_tutorial'], }, { type: 'category', diff --git a/litellm/__init__.py b/litellm/__init__.py index 017e7f3515..026afcf142 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -11,6 +11,7 @@ anthropic_key = None replicate_key = None cohere_key = None openrouter_key = None +huggingface_key = None vertex_project = None vertex_location = None @@ -62,9 +63,6 @@ open_ai_chat_completion_models = [ "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", - 'gpt-3.5-turbo', - 'gpt-3.5-turbo-16k-0613', - 'gpt-3.5-turbo-16k' ] open_ai_text_completion_models = [ 'text-davinci-003' @@ -111,7 +109,22 @@ vertex_text_models = [ "text-bison@001" ] -model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + vertex_chat_models + vertex_text_models +huggingface_models = [ + "meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-2-7b-chat-hf", + "meta-llama/Llama-2-13b-hf", + "meta-llama/Llama-2-13b-chat-hf", + "meta-llama/Llama-2-70b-hf", + "meta-llama/Llama-2-70b-chat-hf", + "meta-llama/Llama-2-7b", + "meta-llama/Llama-2-7b-chat", + "meta-llama/Llama-2-13b", + "meta-llama/Llama-2-13b-chat", + "meta-llama/Llama-2-70b", + "meta-llama/Llama-2-70b-chat", +] # these have been tested on extensively. But by default all text2text-generation and text-generation models are supported by liteLLM. - https://docs.litellm.ai/docs/completion/supported + +model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + huggingface_models + vertex_chat_models + vertex_text_models ####### EMBEDDING MODELS ################### open_ai_embedding_models = [ diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc index 227d66ada62bee5665671e84b2ec06f40aacd3b2..3a9b2741d5df85503b851b0d4244a4783529bafa 100644 GIT binary patch delta 322 zcmdlWcS??TIWI340}zx|+)7z8kynzDXQH|nGh>v%#DdxOf=mqV3@H*V3@MVS3|T@T zbs(^enSo(75JNzeFcU+n@G>B0HB5*hN+d<9g<-NEqnK3+W0Z6{LyGhw#weK-nRLb| zSs+gq$dd!oazI)>MII=s0OTnEd5S5DDN5-~QA(*YlNT_`$)_q+0ZolkPEiIbQb|!s zQO!}!RgF?*WB{_&awgwqR5i#oh%x~4)pHC(K;r5-MnIY&MI*;J*Cfg$*EGsB*DT5` z*F4HR*CNVdvkucTCPwwimsw&M)g~LTUSQPT%*bZK%xF2;jB_aue@1C~dS+gFT4HkQ h=4+h47`Y;WPGAJ$;^4`*xPw`Q#2WZF>+rBJ0{|S=OHBX( delta 285 zcmX>lw?U3~IWI340}v>B+(_x2$ScXnF;U%XV!hnN$u;}}ObqS}DH1IVDUzuSS%Q=0 z7{z6Um>5!pmH{=a2I&HWDB%>T7KSL16voMZjAAzF3@Or!7^9?9WYQUS*6nUViJdmdVmrYNN|MJc9Ar^-|T4VyfVQBF2WDMcBmNI69%MKwnyS0ze? 
zkpak7%Tdi$i<*3!QB?=bSI;pB0g0>U7y@aA6pb9CT;nL?T$3o1T+=AiT(c;%T=OXN z%?3=%m>AV2UuTJ7RGnqGb?4CBXl1Nww2BB+ zpd`9sbOX%x=*F0D)JFu9Aa6KRje0n!V~z;I3-D#fXbpD~Z%JO%74wh!W3{8ThDOer zO?=2xHUlQ}X#tokq|OMTMc2g%5+R{FT6{$qt(ObIcrh(mZ*oGQr8LL*b5Q~k4YX{% z1lSWIC`HEeDag@H^kE@(Q4f z{wq>d(nu>oehaMv*hI-OLE69%nDi3Yz|_P5h2>#6srC#FrYpn`iKPKt`MBNk zxye+KHd&uCksebk}!r@yqlO$A10DFXBjurW zTKPGF`LuFcSH4UZX#G++rE2F2QP4b^K-=a!=R~jpUv{Ltn)ev+Ag`ighOO*AkX3$^ zuv64ch3NKFvs$t&4n4eQeGjCz(k|MqZGB33+7uzFZFFyHJ1v}b!B$F5w1>zt@AAvg`_Y2 zm&zdN-?1DZWaD3I^YRM9J+MApI8(T(ovBus`=h&JZK)lx_S8-yXsxpyhS?`_^GPRv zG}q@*!WR zWlb^;@mKS9w3k95Z<{ki5#Y-XBB+;^aVWNtrCajSuC?>h>b{x#R^{mJnj5peeuI|@ z7Su1OS5_r!1r|tksNPiy9x5nyjg}iBTCvtKe|XLc5y6)o5YYq6I7B39E|57gC;Hd+ z&SgS{RL9pl?jiRG_lhTldx&)OFx$reQQb{EyuIcw>vpgqdDh!mh@Rw`a{5?R` zEfFO#rRs$tN+W6{u7sk&5SvIcRAJ5B?e|+zjt%g={sCte#q}U66JFHN&Tsfj^Q&N) z#8c$+v3mZAznOfC*VNXvNI;TAX1KB@SUi?c4OPix(AXWoXKnzxS(Wcqf!xV2)Ru2I z9Q7g72;=}_*iH<6F!aWb0?>29Ob)3sWj*|twdXbtA!8W9F#sz8A>lXxPHKIwAL8l+ z0t_x20rc6Xnmof_2M|ZtX8!7?)A>z+OF7v{2Y6#$FKOn9x(_P z)zE5+n*%2#7Yf-=grZY&A7gQ*=k~1xjc=FKx!l;PK$R8So_Y9epo;9}uLkO8jP8Sb zW~l{Nx(|a%CgMs`X5U0Eo?yKs3@1A(W5R+f>{AAZ$%$wr3}-YLN<@Nc{H&}n<1|bo z+Z<8l7#l-FVFXVj(HT-y9&BzKN+9W3B$>i-_8cO`YZDZXMr1{0&!gBBf=Wg(}I6<#=rB4oF{T#>A$)A*v0sJ-mR}RynZC(J$1u->ZbQp#_CycOYgW_-*UJ9 zc*pB)8ChKvgaf7#(`WGhneH~hYQH!v!W0(o#wjHpFlXGY9}9w~-~4w2@_+C0mt}^7 znUd2tN>1M_Iepc(Q0Tr|xKOl_&vbgy!+Yk$z@3@l6B+lN@P8xkp7>l_m3`nvE_J8?!PSPs8ZSUO) zlTHRoVfvt~cN0v1+P!ZxOn}YEy=F*OFf?uGQj8e zk2K;PED_xb1(=HBwLSs_-8w#{!r54VFZ0G;i9F4Z_e$MZ18wco8#9Tt(SD1jeDch8W%!#uWi4Gd3o}1FTH2N05m3y)ME*XAknP9olgE8%X>< zf*&9-PVdr7VZ>_q4A$3H0g-OLo`U^s;r!24gg^ijV$u*uNCM>5l{RyX&&8Pfm17f{#4fZ; zwN+iybZ1-BHr?pdrtO+Z-8^;eACp$C-J{YZqtOXnjjip}O`EihY+YNeQ@8KC<`JSz z*|G2MeD^!w`Ofz`=iIM7M}GEgQt*PwlqWztF)mZ@osSo6A-6qV`)Zp&N!UJSrvf#E z9g(eLTZy0&B>Pka%3`B=SBW6J3{S3&Ik}zKN{nGkq;jk>QZ-hk%+O7__-Sw1ED>lv zHG#SUS{Nd<=!zI5A>yy5<_p5u?NT8qFQ&ztRR#rGLUqcoloF7rp_a`Opic;)6qztl zkfFu!lqwXVWmojndO^Gj_PxyGUTghWE!_gE)vc~pzF8NnR!Yl3lbYVfovMlwH?@Iu zJ+%WiP;yjoJNOZmoaq-oAevgxTESa1AE{PsnOg1BDsqgQEjIoeeZH6<9plgHD=o)T zLW)qs6;-PIIdYMtZI|++yE87sJ46w#6^+r||3lCitu`|yrp~7ZHmYJ;PH^S>Zq#?QthDo=m5t3 zd6H_qHh7cMJJ4yegFki+? zK|+PJ=YG!&nGx<2?-FK+`}kqj%3GbgiIvBlCu_EWLbr}JK*t&pnh=^1u)G+Kv+exv z&Q9Xxjg@w-!f57)D_vSIjt%gc%0Xj}#P>n|cI6S19moXMoYTo_xVx%_T;z9E)y&ob z%Plf5bh0+iVv%@~4I)pa@Yo5USs9M1lhS=jV0Yt-ImSZ4P}CpxF)0y`MH7-+lbf&? zCnz@cBc&LQ>oT-}jR45H0F(SliNXZ_S=F?C7)M4CjsUD%iHu|XF_+7dLGm~Pnwzt! zoBz~Rm9MPU1jm5rat-5= zjWxKX*6CD2hUu2`nRI$uN+f;BvvCO)odG74(TUoJZmG5G?1olUrHzHd{)qqo^1AcU zO>g;)0jSlmlnlk9{Pci@U$5KZ^`g{&DL~mE-4~GJDEm^b zuI1wO0$v|tqx`VD0gkG3?zS?lBiZ0TH_as9z0z5xL)C-)$L>eT-F&pZg;4%LeWwvO z19Kt-`J43}B*s6juNt@$7_xpfiV0aAk0p|FUMLZYCX#-r5m`+`fuwBcU$aUHv1~dc zG034WS2AptpKPen&Y|)L_`?k+T&1Ws?}R@*E%jsWnI0JKp4DPUxgA{J=+Ih`?B)%P z4${E)H`dNpAp}O44RkQxa=|HDTy!Hi0S&iW_P1>36tXCPyouz2i5LN#_pA( zxk%rS!|G5{iZDMa8b|mRGK>7tImi+? 
zrj)ixsdQi5e>NQRQ+6K3^iq_@W1(nL7QtpVi^6jV5Ad0$YIP5g{Mn{1mlKaR20q|e6b1wKU_iwA=#-8HYC2PfeS+=ZhF0y1Uf5Pz<$K{z< zbgxA&ITp+Z7tIIfCO^=cF7<=Ru958N^DBaIRCTxNL+nne;4MJ>H*^uz8GO^tsIDsr zazN)e%Z8FmgU`?fgKN>?S`mn@;YHmmo!hf@9{4X8STF16t&3aS3kCIy1@&2d{gS=< zn}?n{^yKiB;TQWB?akSi-HUc_R$IiVA8a8LHcSk`(E{1@_1F6N)iYELd1 zOXh|@SvC}}2x_rs1Nig1dO!D?Ih*HM$nz}bd9rHHC(u?5AO(E_cfo_lb!U2ock=dj zKqr&IQs~|<8`4Af=ZZl$bRQXZAok|Xw2L3>>t{vu{(YQYunDlMz2=*&r&u5nV z>NQ7Ofn#0YO`6|M95cB{LG%wN{3vYaL`~U}(2>k@81d78xQ<}2b;AmZ6Zy5%^pNS3rRs|my z!$Vz#%5Fph6f1tspBS1Tb(znG%;c1kVhy_pgR(IYi%x`suql3Nkx0c3R`7C*~0_8g|Zc`GCsw>K6+k_Oa2FiCfD@< diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc index 451db73573768976d8c30bc082a4e5d0fd148123..92c3af615dfe18cf545c6a164d9732059e088a3a 100644 GIT binary patch delta 2651 zcmZ`(dr(x@8NcW5y}Mj+*Co5$WqHUOTwcnEh=!m;uqqfkighMxMC4pxVcEyoReW5) zSOr8-^fX9JoFOwwW1FVJ>_18WNjmLh#?GX@%iO8kozmEuP9J|2noc?s|LOVe5gXk zO*B2YMdZohh)8b+v-m-QD#3$p(a&RkP>fi=Tgh=MrwXb_{}8M|^Yn2r4~^37&>oX1 z;-Vjia_E=gmnaRnkcoa2YGvBo6|Fsg3i&&*RTa)~#EGRrwqSKroT>^moj5l#r#;xF za?DA=>tEn(PdewTB2Mq)VHYABJ7%xh5uD~pwtF>R?lV>vk7p-&Oxy^9XVm;f<*;*Q z!^1cOqK5@6zaR@&GK(CXX*fKe0XWPPcbZ!j*>T*EA-x~2w)rosJkDB)GXypC_*K+v zTo|N?Q&w!olBdd)O+%sW>rIih2hFCRhF)R1wZ{X$Li$zM-DX~UcM;oa2~?AcaL&8D zYW_v;PZ=3LdO{GysV2homyn7sAwp$7o zoh~i2Jm!pUAJ~D=GL4+wjFzJxoP7-qFNvIe`;wWn70kZ2V6D-tHH)?kP2PB~DrT$K zZS_l>wC)WQPcPEbH_dHXNUzh<>*5ugH2cOy*`t-8)S8dz%_m?|eljL^=yHcN{VDNy)L8d|a20Yxd*x zKj((zeVg~yat~@#TZ&8%8uME=n0~iGY$=M4oKHdMX7t8|fAFY*4h(&UcF}Fa8`=Ei z@DEd8hEfB(9sP88!IXLgWPs1;TR$G%ZM-8^_pfUi+~ySSD7*TjJWj#hNqN5`j==b#QNHzUp%)^r%^sjj`uOez2%^T*ipz^FznOJ0H z5!`Hx`{*0{~|tUIreMk!J67T2Pj^^@M&jM>mEnI-SLV{(NqSB!b$R_BbQSqs=V+chVBQl@Wi(n{frJ9B5c zb!X}98Qoc}S*ss8<*BB)b^X{OlF5>z|9(^EeLQ_QK8gNCm23OZZ|VJO9|}R0>!tH9 znFc4;p*%V`(NOAdWV$Dvmuy=rVap2nDw(T?{`!QAwqDD$vlk($$cwt!kFh?LKA0$R zku9t;2}@^Bz}u;W6Czf8kq*zgTa$8{LArjZ-f3De=_*RLzyYQJqQEj+XP`gmQ=nzG z2Tt`eP8^)fMT0ajSzvw)%xtC8lmBbhQj6l&blqCA#PQO8bgx)1Z~DA{p}b8iZ(DS@ zW&}O6I_B7i>*4Bzvn1V8q;*BVzy~BUb;Cn(02YucM*$wt2?P=X~Eo+*6)5Q|7 zK7o445S^KRMZCkLl4xRDF`12My>$1TBJnQMHqrh&@8yld9@hXS7xg^^6e+ZPG+DSomy8h$5;eqgj|qc z06y(g8p%CoWgi#zMLl=#n;d%(VlQkB9^&agXSdbtLWo_!_v5mm|LGrER!Zupk9uTU5MMgKe}Y>&~Iwvfay^Q`*95 z*;eV6aw#y}v?LTFkbnfq`e1x9(I`HMF3JO0L+pcSOhlJpFfl}(|E!b-;_RL;^Z)-h z-+wu0KDaC#_)N$@lb7c}e3ds`BR}lAnEzuG3oLQa+ zN|i`e`4fqRk&5r-kN0_vI<`IO`vh*MNDa=;nYyuQLI8nYgVa}6x9f2bFmsTW0>4Mbf3Nj8xpII2m+Jw;+_ z#hzTnrKH#-JX`}1Mm3*D;@x=73a3`!VexncyQzDd z=p27?3nN3I+&_p7JctSEll$dv+>Ls~X4H-SrK#qwE1135eQK_SR`MgX8gp0-Pzg{4 z(8vRtg##<=DOu3Q#S>_vY;IaA%?iuqitF~7Y1h2-)2`T;T(vLBqJ`G-31>yxQE_dW z^A@U3b@Z&mILj{dF2LE;MDO#M{j{y$Tko&0LYJx>4W;6xnmG*~@k@`)_d$`J7;&-v z#w|RQdaLhe0YAkyCqBVXv-ty$^ZhFWFE~~}Q2={W=LfzO9j}57@FCl|BiY6`=DH0H z9YHE;;f=Ql1iy`1L?UpMh{#bU5=RS>Ia-OO3Hhy=nQOZ=J=)o%Q#H`Jkp^Nhl?ErL z$@O3l&+U^eG30{V;c;8|dvohnxDS$n6p39Q@?z7h29vYRMNktm@Pi(|uz?=qalcF} z*}CLoZZRWiYIqZU1G<71G|0!2$pwX9@UK{FZ*T0YW>agZI|HM>{^mb57pA#??h zEVLC=+u7H<_TUVAefJ={vA0P04YRv@_u#W^WV8|QU{^=aNDB?r!7h5-c{YB|Tznj} z+4$%a?5ll+PGTiC1BY-MfB1;)EWH2G?HOyZJsJwCT1F;nSP%Fexv#H)pQm*kGEzkC z(Q`A{425V=pG{G~ttfO25c@emHxSAXMmu9+RS!iYO7zVR+73>;4r~(80~)-AIebw! 
z-w3pql^-a#+<mc~(xpq&zGdkdHECN7E#}k>ejBX;Er+6XMoCv# zRg}zhMTvI>!jV8Hx2VFyGypJJDeVOlUWqAUCZ#Rl1Na=$Dm<8&{1s#co4hoOCc(wZ z4_3`Cl%LF^hao50lG=8#8snAh?4erA8W68zzaHwg?BLtGLMn9lWvRLj+(CdlP|QAn zm%!`-_z2980k(kI&3<`z>HPJu1!u+<(p9PiBE8i7I%F*t4(W6q#7*5c);P9=-^u>5 zxps5jxPxwG=f+mbhdF0@>gJd#TFh+i?6u>M$nS8kpCyi;C>sKeB)~BEo}=?1pM^wq zHL``bv3k4`H?q$0VqBjZ8js;;%qwr+b90K&A{_yMOcaD4nD$?`+T80}Mc?C5{zqCM nReSP+=&Z#RSu_a-W`Ccm@h`%B(htKe@cnT=5W#HI>Gl5t=pWYr diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py index 5b5d928b20..c617a0ae0f 100644 --- a/litellm/llms/anthropic.py +++ b/litellm/llms/anthropic.py @@ -13,6 +13,7 @@ class AnthropicError(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message + super().__init__(self.message) # Call the base class constructor with the parameters it needs class AnthropicLLM: @@ -75,7 +76,6 @@ class AnthropicLLM: print_verbose(f"raw model_response: {response.text}") ## RESPONSE OBJECT completion_response = response.json() - print(f"completion_response: {completion_response}") if "error" in completion_response: raise AnthropicError(message=completion_response["error"], status_code=response.status_code) else: diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py new file mode 100644 index 0000000000..7b88d03e65 --- /dev/null +++ b/litellm/llms/huggingface_restapi.py @@ -0,0 +1,94 @@ +## Uses the huggingface text generation inference API +import os, json +from enum import Enum +import requests +from litellm import logging +import time +from typing import Callable + +class HuggingfaceError(Exception): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + super().__init__(self.message) # Call the base class constructor with the parameters it needs + +class HuggingfaceRestAPILLM(): + def __init__(self, encoding, api_key=None) -> None: + self.encoding = encoding + self.validate_environment(api_key=api_key) + + def validate_environment(self, api_key): # set up the environment required to run the model + self.headers = { + "content-type": "application/json", + } + # get the api key if it exists in the environment or is passed in, but don't require it + self.api_key = os.getenv("HF_TOKEN") if "HF_TOKEN" in os.environ else api_key + if self.api_key != None: + self.headers["Authorization"] = f"Bearer {self.api_key}" + + def completion(self, model: str, messages: list, custom_api_base: str, model_response: dict, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None): # logic for parsing in - calling - parsing out model completion calls + if custom_api_base: + completion_url = custom_api_base + elif "HF_API_BASE" in os.environ: + completion_url = os.getenv("HF_API_BASE") + else: + completion_url = f"https://api-inference.huggingface.co/models/{model}" + prompt = "" + if "meta-llama" in model and "chat" in model: # use the required special tokens for meta-llama - https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + prompt = "" + for message in messages: + if message["role"] == "system": + prompt += "[INST] <>" + message["content"] + elif message["role"] == "assistant": + prompt += message["content"] + "[INST]" + elif message["role"] == "user": + prompt += message["content"] + "[/INST]" + else: + for message in messages: + prompt += f"{message['content']}" + ### MAP INPUT PARAMS + # max tokens + if "max_tokens" in optional_params: + value = optional_params.pop("max_tokens") + optional_params["max_new_tokens"] 
= value + data = { + "inputs": prompt, + # "parameters": optional_params + } + ## LOGGING + logging(model=model, input=prompt, additional_args={"litellm_params": litellm_params, "optional_params": optional_params}, logger_fn=logger_fn) + ## COMPLETION CALL + response = requests.post(completion_url, headers=self.headers, data=json.dumps(data)) + if "stream" in optional_params and optional_params["stream"] == True: + return response.iter_lines() + else: + ## LOGGING + logging(model=model, input=prompt, additional_args={"litellm_params": litellm_params, "optional_params": optional_params, "original_response": response.text}, logger_fn=logger_fn) + print_verbose(f"raw model_response: {response.text}") + ## RESPONSE OBJECT + completion_response = response.json() + print(f"response: {completion_response}") + if isinstance(completion_response, dict) and "error" in completion_response: + print(f"completion error: {completion_response['error']}") + print(f"response.status_code: {response.status_code}") + raise HuggingfaceError(message=completion_response["error"], status_code=response.status_code) + else: + model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"] + + ## CALCULATING USAGE + prompt_tokens = len(self.encoding.encode(prompt)) ##[TODO] use the llama2 tokenizer here + completion_tokens = len(self.encoding.encode(model_response["choices"][0]["message"]["content"])) ##[TODO] use the llama2 tokenizer here + + + model_response["created"] = time.time() + model_response["model"] = model + model_response["usage"] = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens + } + return model_response + pass + + def embedding(): # logic for parsing in - calling - parsing out model embedding calls + pass \ No newline at end of file diff --git a/litellm/main.py b/litellm/main.py index 3abdccddf3..d17bebd0d1 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -7,6 +7,7 @@ import litellm from litellm import client, logging, exception_type, timeout, get_optional_params, get_litellm_params from litellm.utils import get_secret, install_and_import, CustomStreamWrapper, read_config_args from .llms.anthropic import AnthropicLLM +from .llms.huggingface_restapi import HuggingfaceRestAPILLM import tiktoken from concurrent.futures import ThreadPoolExecutor encoding = tiktoken.get_encoding("cl100k_base") @@ -222,7 +223,6 @@ def completion( response = CustomStreamWrapper(model_response, model) return response response = model_response - elif model in litellm.openrouter_models or custom_llm_provider == "openrouter": openai.api_type = "openai" # not sure if this will work after someone first uses another API @@ -305,37 +305,15 @@ def completion( "total_tokens": prompt_tokens + completion_tokens } response = model_response - elif custom_llm_provider == "huggingface": - import requests - API_URL = f"https://api-inference.huggingface.co/models/{model}" - HF_TOKEN = get_secret("HF_TOKEN") - headers = {"Authorization": f"Bearer {HF_TOKEN}"} - - prompt = " ".join([message["content"] for message in messages]) - ## LOGGING - logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn) - input_payload = {"inputs": prompt} - response = requests.post(API_URL, headers=headers, json=input_payload) - ## LOGGING - logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": response.text}, 
logger_fn=logger_fn) - if isinstance(response, dict) and "error" in response: - raise Exception(response["error"]) - json_response = response.json() - if 'error' in json_response: # raise HF errors when they exist - raise Exception(json_response['error']) - - completion_response = json_response[0]['generated_text'] - prompt_tokens = len(encoding.encode(prompt)) - completion_tokens = len(encoding.encode(completion_response)) - ## RESPONSE OBJECT - model_response["choices"][0]["message"]["content"] = completion_response - model_response["created"] = time.time() - model_response["model"] = model - model_response["usage"] = { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens - } + elif model in litellm.huggingface_models or custom_llm_provider == "huggingface": + custom_llm_provider = "huggingface" + huggingface_key = api_key if api_key is not None else litellm.huggingface_key + huggingface_client = HuggingfaceRestAPILLM(encoding=encoding, api_key=huggingface_key) + model_response = huggingface_client.completion(model=model, messages=messages, custom_api_base=custom_api_base, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn) + if 'stream' in optional_params and optional_params['stream'] == True: + # don't try to access stream object, + response = CustomStreamWrapper(model_response, model, custom_llm_provider="huggingface") + return response response = model_response elif custom_llm_provider == "together_ai": import requests @@ -383,7 +361,7 @@ def completion( prompt = " ".join([message["content"] for message in messages]) ## LOGGING - logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn) + logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"litellm_params": litellm_params, "optional_params": optional_params}, logger_fn=logger_fn) chat_model = ChatModel.from_pretrained(model) @@ -434,13 +412,13 @@ def completion( ## LOGGING logging(model=model, input=messages, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn) args = locals() - raise ValueError(f"Invalid completion model args passed in. Check your input - {args}") + raise ValueError(f"Unable to map your input to a model. Check your input - {args}") return response except Exception as e: ## LOGGING logging(model=model, input=messages, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens}, logger_fn=logger_fn, exception=e) ## Map to OpenAI Exception - raise exception_type(model=model, original_exception=e) + raise exception_type(model=model, custom_llm_provider=custom_llm_provider, original_exception=e) def batch_completion(*args, **kwargs): batch_messages = args[1] if len(args) > 1 else kwargs.get("messages") diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index c64c845362..f639e327db 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -49,6 +49,17 @@ def test_completion_hf_api(): except Exception as e: pytest.fail(f"Error occurred: {e}") +def test_completion_hf_deployed_api(): + try: + user_message = "There's a llama in my garden 😱 What should I do?" 
+ messages = [{ "content": user_message,"role": "user"}] + response = completion(model="meta-llama/Llama-2-7b-chat-hf", messages=messages, custom_llm_provider="huggingface", custom_api_base="https://s7c7gytn18vnu4tw.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn) + # Add any assertions here to check the response + print(response) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + +test_completion_hf_deployed_api() def test_completion_cohere(): try: response = completion(model="command-nightly", messages=messages, max_tokens=500) diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index b7332772fb..317dea904b 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -23,6 +23,17 @@ try: for chunk in response: print(chunk['choices'][0]['delta']) score +=1 +except: + print(f"error occurred: {traceback.format_exc()}") + pass + + +# test on anthropic completion call +try: + response = completion(model="meta-llama/Llama-2-7b-chat-hf", messages=messages, custom_llm_provider="huggingface", custom_api_base="https://s7c7gytn18vnu4tw.us-east-1.aws.endpoints.huggingface.cloud", stream=True, logger_fn=logger_fn) + for chunk in response: + print(chunk['choices'][0]['delta']) + score +=1 except: print(f"error occurred: {traceback.format_exc()}") pass \ No newline at end of file diff --git a/litellm/utils.py b/litellm/utils.py index 6fc85aaa10..f57c390cbe 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -589,7 +589,7 @@ def modify_integration(integration_name, integration_params): if "table_name" in integration_params: Supabase.supabase_table_name = integration_params["table_name"] -def exception_type(model, original_exception): +def exception_type(model, original_exception, custom_llm_provider): global user_logger_fn exception_mapping_worked = False try: @@ -640,6 +640,17 @@ def exception_type(model, original_exception): elif "CohereConnectionError" in exception_type: # cohere seems to fire these errors when we load test it (1k+ messages / min) exception_mapping_worked = True raise RateLimitError(f"CohereException - {original_exception.message}") + elif custom_llm_provider == "huggingface": + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError(f"HuggingfaceException - {original_exception.message}") + elif original_exception.status_code == 400: + exception_mapping_worked = True + raise InvalidRequestError(f"HuggingfaceException - {original_exception.message}", f"{model}") + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError(f"HuggingfaceException - {original_exception.message}") raise original_exception # base case - return the original exception else: raise original_exception @@ -715,8 +726,9 @@ def get_secret(secret_name): # wraps the completion stream to return the correct format for the model # replicate/anthropic/cohere class CustomStreamWrapper: - def __init__(self, completion_stream, model): + def __init__(self, completion_stream, model, custom_llm_provider=None): self.model = model + self.custom_llm_provider = custom_llm_provider if model in litellm.cohere_models: # cohere does not return an iterator, so we need to wrap it in one self.completion_stream = iter(completion_stream) @@ -745,6 +757,16 @@ class CustomStreamWrapper: return extracted_text else: return "" + + def handle_huggingface_chunk(self, chunk): + chunk = chunk.decode("utf-8") + if 
chunk.startswith('data:'): + data_json = json.loads(chunk[5:]) + if "token" in data_json and "text" in data_json["token"]: + return data_json["token"]["text"] + else: + return "" + return "" def __next__(self): completion_obj ={ "role": "assistant", "content": ""} @@ -763,6 +785,9 @@ class CustomStreamWrapper: elif self.model in litellm.cohere_models: chunk = next(self.completion_stream) completion_obj["content"] = chunk.text + elif self.custom_llm_provider and self.custom_llm_provider == "huggingface": + chunk = next(self.completion_stream) + completion_obj["content"] = self.handle_huggingface_chunk(chunk) # return this for all models return {"choices": [{"delta": completion_obj}]}
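
Taken together, the new provider can be exercised end to end much like the added tests do. Below is a minimal usage sketch, assuming the placeholder Inference Endpoint URL from `test_streaming.py` is swapped for your own deployment and that `HF_TOKEN` is exported if the endpoint is private:

```
# Minimal sketch of the new huggingface code path (non-streaming and streaming).
# The endpoint URL is the placeholder used in the tests; replace it with your own.
from litellm import completion

messages = [{"role": "user", "content": "There's a llama in my garden, what should I do?"}]
custom_api_base = "https://s7c7gytn18vnu4tw.us-east-1.aws.endpoints.huggingface.cloud"

# Non-streaming: returns an OpenAI-style response object
response = completion(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=messages,
    custom_llm_provider="huggingface",
    custom_api_base=custom_api_base,
)
print(response)

# Streaming: returns a CustomStreamWrapper that yields OpenAI-style deltas
stream = completion(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=messages,
    custom_llm_provider="huggingface",
    custom_api_base=custom_api_base,
    stream=True,
)
for chunk in stream:
    print(chunk["choices"][0]["delta"])
```

The streaming branch relies on the new `handle_huggingface_chunk` helper, which parses the `data:` lines emitted by the endpoint and surfaces `token.text` as the delta content.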
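
The chat-prompt loop in `HuggingfaceRestAPILLM.completion()` targets the Llama-2 special-token format described in the blog post it links (https://huggingface.co/blog/llama2#how-to-prompt-llama-2). For reference, here is a sketch of that format for a system message plus one user turn; the helper name and exact whitespace are illustrative, not part of the patch:

```
# Illustrative only: [INST]/[/INST] and <<SYS>>/<</SYS>> are the Llama-2 chat
# markers from the Huggingface blog post referenced in the patch.
def build_llama2_chat_prompt(messages):
    system = ""
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            # the system block sits inside the first [INST], wrapped in <<SYS>> tags
            system = f"<<SYS>>\n{message['content']}\n<</SYS>>\n\n"
        elif message["role"] == "user":
            prompt += f"[INST] {system}{message['content']} [/INST]"
            system = ""
        elif message["role"] == "assistant":
            prompt += f" {message['content']} "
    return prompt

print(build_llama2_chat_prompt([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hey, how's it going?"},
]))
```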