From cf5060f136f100b905367ac66e74abc7909066f2 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 21 Sep 2023 17:47:52 -0700
Subject: [PATCH] add token trimming to adapt to prompt flag in config

---
 litellm/__pycache__/utils.cpython-311.pyc | Bin 121407 -> 121893 bytes
 litellm/tests/test_config.py              |   2 +-
 litellm/utils.py                          |  26 ++++++++++++++++++----
 pyproject.toml                            |   2 +-
 4 files changed, 24 insertions(+), 6 deletions(-)
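For context, a rough usage sketch of the adapt-to-prompt config path this patch touches. Only the "adapt_to_prompt_size", "trim_messages", "available_models", and "default_fallback_models" keys come from the code below; the model names, token limits, and other values are illustrative assumptions, not part of the patch.

    import litellm

    # Hypothetical config. With adapt_to_prompt_size set, completion_with_config
    # switches to an available model whose context window can hold the prompt;
    # with this patch, if no available model is large enough, the messages are
    # trimmed to fit the largest one instead of being sent as-is.
    config = {
        "adapt_to_prompt_size": True,
        "trim_messages": True,  # flag read by this patch
        "available_models": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"],  # illustrative
        "default_fallback_models": ["gpt-3.5-turbo"],
    }

    messages = [{"role": "user", "content": "a very long prompt ..."}]
    response = litellm.completion_with_config(
        config=config, model="gpt-3.5-turbo", messages=messages
    )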
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index 8734e50251f0bd7130e8edf96aac3a4f07e05db8..486063acbf9fce3718c9f045d9e1223492ebd764 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/tests/test_config.py b/litellm/tests/test_config.py
index b2b48cfb3..73e719cad 100644
--- a/litellm/tests/test_config.py
+++ b/litellm/tests/test_config.py
@@ -88,4 +88,4 @@ def test_config_context_adapt_to_prompt():
         print(f"Exception: {e}")
         pytest.fail(f"An exception occurred: {e}")
 
-# test_config_context_adapt_to_prompt()
\ No newline at end of file
+test_config_context_adapt_to_prompt()
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 4eff73973..5e96d332e 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2802,7 +2802,13 @@ def completion_with_config(config: Union[dict, str], **kwargs):
     fallback_models = config.get("default_fallback_models", None)
     available_models = config.get("available_models", None)
     adapt_to_prompt_size = config.get("adapt_to_prompt_size", False)
-    start_time = time.time()
+    trim_messages_flag = config.get("trim_messages", False)
+    prompt_larger_than_model = False
+    max_model = model
+    try:
+        max_tokens = litellm.get_max_tokens(model)["max_tokens"]
+    except:
+        max_tokens = 2048 # assume curr model's max window is 2048 tokens
     if adapt_to_prompt_size:
         ## Pick model based on token window
         prompt_tokens = litellm.token_counter(model="gpt-3.5-turbo", text="".join(message["content"] for message in messages))
@@ -2811,14 +2817,22 @@
         except:
             curr_max_tokens = 2048
         if curr_max_tokens < prompt_tokens:
+            prompt_larger_than_model = True
             for available_model in available_models:
                 try:
                     curr_max_tokens = litellm.get_max_tokens(available_model)["max_tokens"]
+                    if curr_max_tokens > max_tokens:
+                        max_tokens = curr_max_tokens
+                        max_model = available_model
                     if curr_max_tokens > prompt_tokens:
                         model = available_model
+                        prompt_larger_than_model = False
                 except:
                     continue
-        end_time = time.time()
+        if prompt_larger_than_model:
+            messages = trim_messages(messages=messages, model=max_model)
+            kwargs["messages"] = messages
+
     kwargs["model"] = model
     try:
         if model in models_with_config:
@@ -3052,8 +3066,7 @@ def shorten_message_to_fit_limit(
 # Credits for this code go to Killian Lucas
 def trim_messages(
     messages,
-    model = None,
-    system_message = None, # str of user system message
+    model: Optional[str] = None,
     trim_ratio: float = 0.75,
     return_response_tokens: bool = False,
     max_tokens = None
@@ -3086,6 +3099,11 @@ def trim_messages(
             # do nothing, just return messages
             return
 
+    system_message = ""
+    for message in messages:
+        if message["role"] == "system":
+            system_message += message["content"]
+
     current_tokens = token_counter(model=model, messages=messages)
 
     # Do nothing if current tokens under messages
diff --git a/pyproject.toml b/pyproject.toml
index fd729ed5d..eea1fc43e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.731"
+version = "0.1.732"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
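A small usage sketch of the updated trim_messages() signature. It assumes trim_messages is importable at the package level (litellm.trim_messages) and uses made-up message content and a placeholder model name; with this change the system prompt is read from the messages list itself rather than from the removed system_message parameter.

    import litellm

    messages = [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "please summarize this: " + "very long text " * 5000},
    ]

    # Any {"role": "system"} content is now gathered inside trim_messages()
    # itself, and the conversation is trimmed toward trim_ratio (default 0.75)
    # of the model's context window when max_tokens is not passed explicitly.
    trimmed = litellm.trim_messages(messages=messages, model="gpt-3.5-turbo")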