From 3b4064a58faa0c70ed32aedb252123a9c183d3d4 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 14 Sep 2023 11:17:38 -0700
Subject: [PATCH] move cohere to http endpoint

---
 litellm/__pycache__/main.cpython-311.pyc  | Bin 33733 -> 32578 bytes
 litellm/__pycache__/utils.cpython-311.pyc | Bin 104989 -> 105961 bytes
 litellm/llms/cohere.py                    | 101 ++++++++++++++++++++++
 litellm/main.py                           |  47 ++++------
 litellm/tests/test_completion.py          |   5 +-
 litellm/tests/test_streaming.py           |  26 +++++-
 litellm/utils.py                          |  39 +++++++-
 pyproject.toml                            |   2 +-
 8 files changed, 175 insertions(+), 45 deletions(-)
 create mode 100644 litellm/llms/cohere.py

diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 6b42eec1787cf7b4764e13822069750f286c00c5..3b91b0fb33e88bd4119a01f41f07a841536dc4e7 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index fbebc9cd5ea033f7640906188daf72efe5268db8..76b82a67db3f31c3f7de008a20b84fc3a414b35b 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/llms/cohere.py b/litellm/llms/cohere.py
new file mode 100644
index 000000000..113b6b542
--- /dev/null
+++ b/litellm/llms/cohere.py
@@ -0,0 +1,101 @@
+import os
+import json
+from enum import Enum
+import requests
+import time
+from typing import Callable
+from litellm.utils import ModelResponse
+
+class CohereError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+def validate_environment(api_key):
+    headers = {
+        "accept": "application/json",
+        "content-type": "application/json",
+    }
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
+    return headers
+
+def completion(
+    model: str,
+    messages: list,
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    api_key,
+    logging_obj,
+    optional_params=None,
+    litellm_params=None,
+    logger_fn=None,
+):
+    headers = validate_environment(api_key)
+    completion_url = "https://api.cohere.ai/v1/generate"
+    optional_params = optional_params or {}  # guard against the None default
+    prompt = " ".join(message["content"] for message in messages)
+    data = {
+        "model": model,
+        "prompt": prompt,
+        **optional_params,
+    }
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=prompt,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
+    ## COMPLETION CALL
+    response = requests.post(
+        completion_url, headers=headers, data=json.dumps(data), stream=optional_params.get("stream", False)
+    )
+    if optional_params.get("stream", False):
+        return response.iter_lines()
+    else:
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        completion_response = response.json()
+        if "error" in completion_response:
+            raise CohereError(
+                message=completion_response["error"],
+                status_code=response.status_code,
+            )
+        else:
+            try:
+                model_response["choices"][0]["message"]["content"] = completion_response["generations"][0]["text"]
+            except Exception:
+                raise CohereError(message=json.dumps(completion_response), status_code=response.status_code)
+
+        ## CALCULATING USAGE
+        prompt_tokens = len(
+            encoding.encode(prompt)
+        )
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"]["content"])
+        )
+
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        model_response["usage"] = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+        return model_response
+
+def embedding():
+    # logic for parsing in - calling - parsing out model embedding calls
+    pass
diff --git a/litellm/main.py b/litellm/main.py
index 9edf423e0..26d54cf67 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -32,6 +32,7 @@ from .llms import nlp_cloud
 from .llms import baseten
 from .llms import vllm
 from .llms import ollama
+from .llms import cohere
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, List, Optional, Dict
@@ -547,12 +548,6 @@ def completion(
             input=messages, api_key=openai.api_key, original_response=response
         )
     elif model in litellm.cohere_models:
-        # import cohere/if it fails then pip install cohere
-        try:
-            import cohere
-        except:
-            raise Exception("Cohere import failed please run `pip install cohere`")
-
         cohere_key = (
             api_key
             or litellm.cohere_key
@@ -560,35 +555,23 @@ def completion(
             or get_secret("CO_API_KEY")
             or litellm.api_key
         )
-        co = cohere.Client(cohere_key)
-        prompt = " ".join([message["content"] for message in messages])
-        ## LOGGING
-        logging.pre_call(input=prompt, api_key=cohere_key)
-        ## COMPLETION CALL
-        response = co.generate(model=model, prompt=prompt, **optional_params)
+        model_response = cohere.completion(
+            model=model,
+            messages=messages,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            encoding=encoding,
+            api_key=cohere_key,
+            logging_obj=logging # model call logging is done inside the function, as we may need to modify I/O to fit cohere's requirements
+        )
+
         if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object,
-            response = CustomStreamWrapper(response, model, logging_obj=logging)
+            response = CustomStreamWrapper(model_response, model, logging_obj=logging)
             return response
-        ## LOGGING
-        logging.post_call(
-            input=prompt, api_key=cohere_key, original_response=response
-        )
-        ## USAGE
-        completion_response = response[0].text
-        prompt_tokens = len(encoding.encode(prompt))
-        completion_tokens = len(encoding.encode(completion_response))
-        ## RESPONSE OBJECT
-        model_response["choices"][0]["message"]["content"] = completion_response
-        if response[0].finish_reason:
-            model_response.choices[0].finish_reason = response[0].finish_reason
-        model_response["created"] = time.time()
-        model_response["model"] = model
-        model_response["usage"] = {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": completion_tokens,
-            "total_tokens": prompt_tokens + completion_tokens,
-        }
         response = model_response
     elif (
         (
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 98faeb5c2..e8da75a8c 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -144,7 +144,7 @@ def test_completion_nlp_cloud_streaming():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_nlp_cloud_streaming()
+# test_completion_nlp_cloud_streaming()
 # def test_completion_hf_api():
 #     try:
 #         user_message = "write some code to find the sum of two numbers"
@@ -183,7 +183,6 @@ def test_completion_cohere(): # commenting for now as the cohere endpoint is bei
         # Add any assertions here to check the response
         print(response)
         response_str = response["choices"][0]["message"]["content"]
-        print(f"str response{response_str}")
         response_str_2 = response.choices[0].message.content
         if type(response_str) != str:
             pytest.fail(f"Error occurred: {e}")
@@ -192,6 +191,8 @@ def test_completion_cohere(): # commenting for now as the cohere endpoint is bei
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+# test_completion_cohere()
+
 def test_completion_cohere_stream():
     try:
         messages = [
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index c577567b7..43308aa16 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -3,7 +3,7 @@
 
 import sys, os, asyncio
 import traceback
-import time
+import time, pytest
 sys.path.insert(
     0, os.path.abspath("../..")
 )
@@ -24,6 +24,30 @@ def logger_fn(model_call_object: dict):
 user_message = "Hello, how are you?"
 messages = [{"content": user_message, "role": "user"}]
 
+def test_completion_cohere_stream():
+    try:
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "how does a court case get to the Supreme Court?",
+            },
+        ]
+        response = completion(
+            model="command-nightly", messages=messages, stream=True, max_tokens=50
+        )
+        complete_response = ""
+        # Add any assertions here to check the response
+        for chunk in response:
+            print(f"chunk: {chunk}")
+            complete_response += chunk["choices"][0]["delta"]["content"]
+        if complete_response == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # test on baseten completion call
 # try:
 #     response = completion(
diff --git a/litellm/utils.py b/litellm/utils.py
index 06149c83e..e0c29896e 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1854,7 +1854,7 @@ def exception_type(model, original_exception, custom_llm_provider):
                     llm_provider="replicate",
                     model=model
                 )
-        elif model in litellm.cohere_models: # Cohere
+        elif model in litellm.cohere_models or custom_llm_provider == "cohere": # Cohere
            if (
                "invalid api token" in error_str
                or "No API key provided." in error_str
@@ -1872,6 +1872,21 @@
                     model=model,
                     llm_provider="cohere",
                 )
+            elif hasattr(original_exception, "status_code"):
+                if original_exception.status_code in (400, 498):
+                    exception_mapping_worked = True
+                    raise InvalidRequestError(
+                        message=f"CohereException - {original_exception.message}",
+                        llm_provider="cohere",
+                        model=model
+                    )
+                elif original_exception.status_code == 500:
+                    exception_mapping_worked = True
+                    raise ServiceUnavailableError(
+                        message=f"CohereException - {original_exception.message}",
+                        llm_provider="cohere",
+                        model=model
+                    )
             elif (
                 "CohereConnectionError" in exception_type
             ): # cohere seems to fire these errors when we load test it (1k+ messages / min)
@@ -2287,14 +2302,10 @@ class CustomStreamWrapper:
         self.model = model
         self.custom_llm_provider = custom_llm_provider
         self.logging_obj = logging_obj
+        self.completion_stream = completion_stream
         if self.logging_obj:
             # Log the type of the received item
             self.logging_obj.post_call(str(type(completion_stream)))
-        if model in litellm.cohere_models:
-            # these do not return an iterator, so we need to wrap it in one
-            self.completion_stream = iter(completion_stream)
-        else:
-            self.completion_stream = completion_stream
 
     def __iter__(self):
         return self
@@ -2359,6 +2370,16 @@
         except:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")
 
+    def handle_cohere_chunk(self, chunk):
+        # cohere streams newline-delimited JSON objects, e.g. {"text": "..."}
+        chunk = chunk.decode("utf-8")
+        try:
+            data_json = json.loads(chunk)
+            return data_json["text"]
+        except (json.JSONDecodeError, KeyError):
+            # surface malformed chunks instead of failing silently
+            raise ValueError(f"Unable to parse response. Original response: {chunk}")
+
     def handle_openai_text_completion_chunk(self, chunk):
         try:
             return chunk["choices"][0]["text"]
@@ -2416,9 +2437,6 @@
                 if text_data == "":
                     return self.__next__()
                 completion_obj["content"] = text_data
-            elif self.model in litellm.cohere_models:
-                chunk = next(self.completion_stream)
-                completion_obj["content"] = chunk.text
             elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = self.handle_huggingface_chunk(chunk)
@@ -2440,6 +2458,9 @@
             elif self.model in litellm.nlp_cloud_models or self.custom_llm_provider == "nlp_cloud":
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = self.handle_nlp_cloud_chunk(chunk)
+            elif self.model in litellm.cohere_models or self.custom_llm_provider == "cohere":
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = self.handle_cohere_chunk(chunk)
             else: # openai chat/azure models
                 chunk = next(self.completion_stream)
                 return chunk # open ai returns finish_reason, we should just return the openai chunk
diff --git a/pyproject.toml b/pyproject.toml
index 5d259cdc3..9a81f0f9d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.625"
+version = "0.1.626"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
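For reference, a usage sketch of the changed code path (mirrors the new test_completion_cohere_stream test above; assumes a Cohere key is resolvable via api_key, litellm.cohere_key, or the CO_API_KEY secret, as listed in main.py):

    import litellm

    messages = [{"content": "Hello, how are you?", "role": "user"}]

    # non-streaming: routed through cohere.completion(), which fills in
    # choices, model, created, and token usage on the ModelResponse
    response = litellm.completion(model="command-nightly", messages=messages, max_tokens=50)
    print(response["choices"][0]["message"]["content"])

    # streaming: the raw response.iter_lines() generator is wrapped in
    # CustomStreamWrapper, and each line is parsed by handle_cohere_chunk
    stream = litellm.completion(model="command-nightly", messages=messages, stream=True, max_tokens=50)
    for chunk in stream:
        print(chunk["choices"][0]["delta"]["content"], end="")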