From aaa57abdddd1065f3755f86dea7f7cfbed03e659 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 13 Sep 2023 19:22:38 -0700
Subject: [PATCH] map finish reason

---
 litellm/__pycache__/main.cpython-311.pyc  | Bin 32927 -> 33061 bytes
 litellm/__pycache__/utils.cpython-311.pyc | Bin 101105 -> 101105 bytes
 litellm/llms/ai21.py                      |  3 +-
 litellm/llms/anthropic.py                 |  1 +
 litellm/llms/huggingface_restapi.py       |  5 +-
 litellm/llms/together_ai.py               |  9 ++--
 litellm/main.py                           |  2 +
 litellm/tests/test_completion.py          | 63 +++++++++++-----------
 litellm/utils.py                          | 35 ++++++------
 pyproject.toml                            |  2 +-
 10 files changed, 64 insertions(+), 56 deletions(-)

diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 63d8dc73ef6e44b85eea47f67255a8010d596f5d..7d1b0da0d773d58fdb6fe518dbe57ffb03f7646a 100644
GIT binary patch
[base85-encoded binary delta data for the compiled __pycache__ .pyc files omitted]
diff --git a/litellm/llms/ai21.py b/litellm/llms/ai21.py
index f3f4a4342..17d5c9bd9 100644
--- a/litellm/llms/ai21.py
+++ b/litellm/llms/ai21.py
@@ -90,7 +90,8 @@ def completion(
         else:
             try:
                 model_response["choices"][0]["message"]["content"] = completion_response["completions"][0]["data"]["text"]
-            except:
+                model_response.choices[0].finish_reason = completion_response["completions"][0]["finishReason"]["reason"]
+            except Exception as e:
                 raise AI21Error(message=json.dumps(completion_response), status_code=response.status_code)
 
         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. 
diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index 021ec4a73..e1634afe0 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -114,6 +114,7 @@ def completion(
             model_response["choices"][0]["message"]["content"] = completion_response[
                 "completion"
             ]
+            model_response.choices[0].finish_reason = completion_response["stop_reason"]
 
         ## CALCULATING USAGE
         prompt_tokens = len(
diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index e2fccb569..1160e6d8d 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -153,9 +153,10 @@ def completion(
         elif task == "text-generation-inference":
             model_response["choices"][0]["message"][
                 "content"
-            ] = completion_response[0]["generated_text"]
-            ## GETTING LOGPROBS
+            ] = completion_response[0]["generated_text"] 
+            ## GETTING LOGPROBS + FINISH REASON
             if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
                 sum_logprob = 0
                 for token in completion_response[0]["details"]["tokens"]:
                     sum_logprob += token["logprob"]
diff --git a/litellm/llms/together_ai.py b/litellm/llms/together_ai.py
index 4f75e6e43..47d6ab677 100644
--- a/litellm/llms/together_ai.py
+++ b/litellm/llms/together_ai.py
@@ -104,14 +104,17 @@ def completion(
                 message=json.dumps(completion_response["output"]), status_code=response.status_code
             )
-        completion_response = completion_response["output"]["choices"][0]["text"]
+        print(completion_response)
+        completion_text = completion_response["output"]["choices"][0]["text"]
 
         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. 
prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len( - encoding.encode(completion_response) + encoding.encode(completion_text) ) - model_response["choices"][0]["message"]["content"] = completion_response + model_response["choices"][0]["message"]["content"] = completion_text + if "finish_reason" in completion_response["output"]["choices"][0]: + model_response.choices[0].finish_reason = completion_response["output"]["choices"][0]["finish_reason"] model_response["created"] = time.time() model_response["model"] = model model_response["usage"] = { diff --git a/litellm/main.py b/litellm/main.py index a7d9d627b..46129c7be 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -529,6 +529,8 @@ def completion( completion_tokens = len(encoding.encode(completion_response)) ## RESPONSE OBJECT model_response["choices"][0]["message"]["content"] = completion_response + if response[0].finish_reason: + model_response.choices[0].finish_reason = response[0].finish_reason model_response["created"] = time.time() model_response["model"] = model model_response["usage"] = { diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 6be7f24d3..934354c2c 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -49,7 +49,7 @@ def test_completion_claude(): print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") - +# test_completion_claude() # aleph alpha # def test_completion_aleph_alpha(): # try: @@ -119,8 +119,8 @@ def test_completion_claude_stream(): # try: # user_message = "write some code to find the sum of two numbers" # messages = [{ "content": user_message,"role": "user"}] -# api_base = "https://wyh9bqfgj2r1klv5.us-east-1.aws.endpoints.huggingface.cloud" -# response = completion(model="facebook/blenderbot-400M-distill", messages=messages, custom_llm_provider="huggingface", task="conversational", api_base=api_base, logger_fn=logger_fn) +# api_base = "https://ecd4sb5n09bo4ei2.us-east-1.aws.endpoints.huggingface.cloud" +# response = completion(model="togethercomputer/LLaMA-2-7B-32K", messages=messages, custom_llm_provider="huggingface", api_base=api_base, logger_fn=logger_fn) # # Add any assertions here to check the response # print(response) # except Exception as e: @@ -141,26 +141,26 @@ def test_completion_claude_stream(): # pytest.fail(f"Error occurred: {e}") -# def test_completion_cohere(): # commenting for now as the cohere endpoint is being flaky -# try: -# response = completion( -# model="command-nightly", -# messages=messages, -# max_tokens=100, -# logit_bias={40: 10}, -# ) -# # Add any assertions here to check the response -# print(response) -# response_str = response["choices"][0]["message"]["content"] -# print(f"str response{response_str}") -# response_str_2 = response.choices[0].message.content -# if type(response_str) != str: -# pytest.fail(f"Error occurred: {e}") -# if type(response_str_2) != str: -# pytest.fail(f"Error occurred: {e}") -# except Exception as e: -# pytest.fail(f"Error occurred: {e}") -## +def test_completion_cohere(): # commenting for now as the cohere endpoint is being flaky + try: + response = completion( + model="command-nightly", + messages=messages, + max_tokens=100, + logit_bias={40: 10}, + logger_fn=logger_fn + ) + # Add any assertions here to check the response + print(response) + response_str = response["choices"][0]["message"]["content"] + print(f"str response{response_str}") + response_str_2 = response.choices[0].message.content + if type(response_str) != str: + 
pytest.fail(f"Error occurred: {e}") + if type(response_str_2) != str: + pytest.fail(f"Error occurred: {e}") + except Exception as e: + pytest.fail(f"Error occurred: {e}") def test_completion_cohere_stream(): try: @@ -750,15 +750,16 @@ def test_completion_with_fallbacks(): #### Test A121 ################### -# def test_completion_ai21(): -# model_name = "j2-light" -# try: -# response = completion(model=model_name, messages=messages) -# # Add any assertions here to check the response -# print(response) -# except Exception as e: -# pytest.fail(f"Error occurred: {e}") +def test_completion_ai21(): + model_name = "j2-light" + try: + response = completion(model=model_name, messages=messages) + # Add any assertions here to check the response + print(response) + except Exception as e: + pytest.fail(f"Error occurred: {e}") +# test_completion_ai21() # test config file with completion # # def test_completion_openai_config(): # try: diff --git a/litellm/utils.py b/litellm/utils.py index fe3efe06b..c5f35cfde 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -830,7 +830,23 @@ def get_optional_params( # use the openai defaults optional_params["top_k"] = top_k if stop != None: optional_params["stop_sequences"] = stop - + elif custom_llm_provider == "huggingface": + if temperature != 1: + optional_params["temperature"] = temperature + if top_p != 1: + optional_params["top_p"] = top_p + if n != 1: + optional_params["n"] = n + if stream: + optional_params["stream"] = stream + if stop != None: + optional_params["stop"] = stop + if max_tokens != float("inf"): + optional_params["max_new_tokens"] = max_tokens + if presence_penalty != 0: + optional_params["repetition_penalty"] = presence_penalty + optional_params["details"] = True + optional_params["task"] = task elif custom_llm_provider == "together_ai" or ("togethercomputer" in model): if stream: optional_params["stream_tokens"] = stream @@ -867,23 +883,6 @@ def get_optional_params( # use the openai defaults optional_params["num_beams"] = num_beams if max_tokens != float("inf"): optional_params["max_new_tokens"] = max_tokens - elif custom_llm_provider == "huggingface": - if temperature != 1: - optional_params["temperature"] = temperature - if top_p != 1: - optional_params["top_p"] = top_p - if n != 1: - optional_params["n"] = n - if stream: - optional_params["stream"] = stream - if stop != None: - optional_params["stop"] = stop - if max_tokens != float("inf"): - optional_params["max_new_tokens"] = max_tokens - if presence_penalty != 0: - optional_params["repetition_penalty"] = presence_penalty - optional_params["details"] = True - optional_params["task"] = task elif custom_llm_provider == "sagemaker": if "llama-2" in model: # llama-2 models on sagemaker support the following args diff --git a/pyproject.toml b/pyproject.toml index 5e247bbf8..6585ca5ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "0.1.620" +version = "0.1.621" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License"