From 4927e5879fd2069e24c0ab3b0f6efbf8eb78d30e Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 30 Aug 2023 19:14:48 -0700
Subject: [PATCH] update baseten handler to handle TGI calls

---
 litellm/__pycache__/main.cpython-311.pyc  | Bin 28027 -> 28135 bytes
 litellm/__pycache__/utils.cpython-311.pyc | Bin 76385 -> 76900 bytes
 litellm/llms/baseten.py                   |  28 +++++++++++++---
 litellm/main.py                           |   4 +--
 litellm/tests/test_completion.py          |  10 +++++-
 litellm/tests/test_streaming.py           |  35 ++++++++++----------
 litellm/utils.py                          |  38 +++++++++++++++-------
 pyproject.toml                            |   2 +-
 8 files changed, 79 insertions(+), 38 deletions(-)

diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 8de1d76380133e09dea7491db31f7732e02c6b47..0d73a670dab2bc0d8122036236963c215104ee24 100644
GIT binary patch
delta 1096
[base85 binary delta omitted]

delta 1001
[base85 binary delta omitted]

diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
GIT binary patch
[base85 binary delta omitted]
diff --git a/litellm/llms/baseten.py b/litellm/llms/baseten.py
index 49753d67b3..b11ae179dd 100644
--- a/litellm/llms/baseten.py
+++ b/litellm/llms/baseten.py
@@ -60,9 +60,12 @@ class BasetenLLM:
             else:
                 prompt += f"{message['content']}"
         data = {
-            "prompt": prompt,
+            # "prompt": prompt,
+            "inputs": prompt, # in case it's a TGI deployed model
             # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
-            **optional_params,
+            # **optional_params,
+            "parameters": optional_params,
+            "stream": True if "stream" in optional_params and optional_params["stream"] == True else False
         }
 
         ## LOGGING
@@ -76,8 +79,9 @@ class BasetenLLM:
             self.completion_url_fragment_1 + model + self.completion_url_fragment_2,
             headers=self.headers,
             data=json.dumps(data),
+            stream=True if "stream" in optional_params and optional_params["stream"] == True else False
         )
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if 'text/event-stream' in response.headers['Content-Type'] or ("stream" in optional_params and optional_params["stream"] == True):
             return response.iter_lines()
         else:
             ## LOGGING
@@ -117,9 +121,23 @@ class BasetenLLM:
                     model_response["choices"][0]["message"][
                         "content"
                     ] = completion_response["completion"]
+                elif isinstance(completion_response, list) and len(completion_response) > 0:
+                    if "generated_text" not in completion_response:
+                        raise BasetenError(
+                            message=f"Unable to parse response. Original response: {response.text}",
+                            status_code=response.status_code
+                        )
+                    model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
+                    ## GETTING LOGPROBS
+                    if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                        sum_logprob = 0
+                        for token in completion_response[0]["details"]["tokens"]:
+                            sum_logprob += token["logprob"]
+                        model_response["choices"][0]["message"]["logprobs"] = sum_logprob
                 else:
-                    raise ValueError(
-                        f"Unable to parse response. Original response: {response.text}"
+                    raise BasetenError(
+                        message=f"Unable to parse response. Original response: {response.text}",
+                        status_code=response.status_code
                     )
 
             ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
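For context: the handler above now sends a TGI-style body ("inputs" plus nested "parameters") and, for non-streaming calls, parses a list-shaped response containing "generated_text" and optional token logprobs. The sketch below only illustrates those shapes; the prompt, parameter values, and sample response are made up for illustration and are not part of the patch.

    # Illustrative sketch of the request/response shapes the updated Baseten handler assumes.
    import json

    prompt = "write me a function to print hello world in python"
    optional_params = {"temperature": 0.7, "top_k": 40}

    # Request body: TGI deployments expect the prompt under "inputs" and the
    # generation kwargs nested under "parameters".
    data = {
        "inputs": prompt,
        "parameters": optional_params,
        "stream": False,
    }
    print(json.dumps(data))

    # Non-streaming TGI responses come back as a list; the handler reads
    # [0]["generated_text"] and, when "details" is present, sums the token logprobs.
    sample_response = [{
        "generated_text": "def hello_world():\n    print('hello world')",
        "details": {"tokens": [{"logprob": -0.12}, {"logprob": -0.34}]},
    }]
    content = sample_response[0]["generated_text"]
    sum_logprob = sum(t["logprob"] for t in sample_response[0]["details"]["tokens"])
    print(content)
    print(sum_logprob)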
diff --git a/litellm/main.py b/litellm/main.py
index 285aed48f5..3b0bbc6450 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1,4 +1,4 @@
-import os, openai, sys, json
+import os, openai, sys, json, inspect
 from typing import Any
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
@@ -682,7 +682,7 @@ def completion(
             litellm_params=litellm_params,
             logger_fn=logger_fn,
         )
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if inspect.isgenerator(model_response) or ("stream" in optional_params and optional_params["stream"] == True): # don't try to access stream object,
             response = CustomStreamWrapper(
                 model_response, model, custom_llm_provider="baseten", logging_obj=logging
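The inspect.isgenerator check matters because the Baseten handler now returns response.iter_lines() (a generator) whenever the server answers with a text/event-stream body, even if the caller never passed stream=True, so completion() can still wrap the result in CustomStreamWrapper. A minimal illustration, using a made-up stand-in for the handler rather than litellm code:

    # Illustrative sketch: iter_lines()-style results are generators, so
    # inspect.isgenerator(model_response) catches the streaming case even when
    # the caller did not ask for streaming explicitly.
    import inspect

    def fake_baseten_handler(server_streams: bool):
        if server_streams:
            def line_iter():  # stands in for requests.Response.iter_lines()
                yield b'data: {"token": {"text": "hello"}}'
            return line_iter()
        return {"choices": [{"message": {"content": "hello"}}]}

    for server_streams in (True, False):
        model_response = fake_baseten_handler(server_streams)
        print(server_streams, inspect.isgenerator(model_response))
    # prints: True True, then False False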
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index d890c54153..93f7fedc89 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -20,7 +20,7 @@ litellm.use_client = True
 # litellm.set_verbose = True
 # litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])
 
-user_message = "Hello, whats the weather in San Francisco??"
+user_message = "write me a function to print hello world in python"
 messages = [{"content": user_message, "role": "user"}]
 
 
@@ -383,6 +383,14 @@ def test_completion_with_fallbacks():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+# def test_baseten():
+#     try:
+#         response = completion(model="baseten/RqgAEn0", messages=messages, logger_fn=logger_fn)
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
 
 # def test_baseten_falcon_7bcompletion():
 #     model_name = "qvv0xeq"
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 306a317eb9..047fc45537 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -25,23 +25,24 @@ user_message = "Hello, how are you?"
 messages = [{"content": user_message, "role": "user"}]
 
 # test on baseten completion call
-try:
-    response = completion(
-        model="wizard-lm", messages=messages, stream=True, logger_fn=logger_fn
-    )
-    print(f"response: {response}")
-    complete_response = ""
-    start_time = time.time()
-    for chunk in response:
-        chunk_time = time.time()
-        print(f"time since initial request: {chunk_time - start_time:.5f}")
-        print(chunk["choices"][0]["delta"])
-        complete_response += chunk["choices"][0]["delta"]["content"]
-    if complete_response == "":
-        raise Exception("Empty response received")
-except:
-    print(f"error occurred: {traceback.format_exc()}")
-    pass
+# try:
+#     response = completion(
+#         model="baseten/RqgAEn0", messages=messages, logger_fn=logger_fn
+#     )
+#     print(f"response: {response}")
+#     complete_response = ""
+#     start_time = time.time()
+#     for chunk in response:
+#         chunk_time = time.time()
+#         print(f"time since initial request: {chunk_time - start_time:.5f}")
+#         print(chunk["choices"][0]["delta"])
+#         complete_response += chunk["choices"][0]["delta"]["content"]
+#     if complete_response == "":
+#         raise Exception("Empty response received")
+#     print(f"complete response: {complete_response}")
+# except:
+#     print(f"error occurred: {traceback.format_exc()}")
+#     pass
 
 # test on openai completion call
 try:
diff --git a/litellm/utils.py b/litellm/utils.py
index 61b0d56450..3c4a57ead6 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -735,7 +735,8 @@ def get_optional_params(  # use the openai defaults
     elif custom_llm_provider == "baseten":
         optional_params["temperature"] = temperature
         optional_params["stream"] = stream
-        optional_params["top_p"] = top_p
+        if top_p != 1:
+            optional_params["top_p"] = top_p
         optional_params["top_k"] = top_k
         optional_params["num_beams"] = num_beams
         if max_tokens != float("inf"):
@@ -1739,18 +1740,31 @@ class CustomStreamWrapper:
             return chunk["choices"][0]["delta"]["content"]
 
     def handle_baseten_chunk(self, chunk):
-        chunk = chunk.decode("utf-8")
-        data_json = json.loads(chunk)
-        if "model_output" in data_json:
-            if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
-                return data_json["model_output"]["data"][0]
-            elif isinstance(data_json["model_output"], str):
-                return data_json["model_output"]
-            elif "completion" in data_json and isinstance(data_json["completion"], str):
-                return data_json["completion"]
+        try:
+            chunk = chunk.decode("utf-8")
+            if len(chunk) > 0:
+                if chunk.startswith("data:"):
+                    data_json = json.loads(chunk[5:])
+                    if "token" in data_json and "text" in data_json["token"]:
+                        return data_json["token"]["text"]
+                    else:
+                        return ""
+                data_json = json.loads(chunk)
+                if "model_output" in data_json:
+                    if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
+                        return data_json["model_output"]["data"][0]
+                    elif isinstance(data_json["model_output"], str):
+                        return data_json["model_output"]
+                    elif "completion" in data_json and isinstance(data_json["completion"], str):
+                        return data_json["completion"]
+                    else:
+                        raise ValueError(f"Unable to parse response. Original response: {chunk}")
+                else:
+                    return ""
             else:
-                raise ValueError(f"Unable to parse response. Original response: {chunk}")
-        else:
+                return ""
+        except:
+            traceback.print_exc()
             return ""
 
     def __next__(self):
diff --git a/pyproject.toml b/pyproject.toml
index 4297dcc947..3f17c7723e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.508"
+version = "0.1.509"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
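In short, handle_baseten_chunk now recognizes TGI server-sent-event lines ("data: {...}" carrying a token.text field) in addition to the older Baseten output shapes, and returns an empty string for anything it cannot parse instead of raising out of the stream. A simplified, self-contained paraphrase of that parsing logic follows; it is not the exact litellm function, and the byte strings are made-up examples.

    # Simplified sketch of the chunk-parsing behavior added in this patch.
    import json

    def parse_baseten_chunk(chunk: bytes) -> str:
        try:
            text = chunk.decode("utf-8")
            if not text:
                return ""
            if text.startswith("data:"):  # TGI streams server-sent events
                data_json = json.loads(text[5:])
                return data_json.get("token", {}).get("text", "")
            data_json = json.loads(text)  # older Baseten output shapes
            if "model_output" in data_json:
                out = data_json["model_output"]
                if isinstance(out, dict) and isinstance(out.get("data"), list):
                    return out["data"][0]
                if isinstance(out, str):
                    return out
            if isinstance(data_json.get("completion"), str):
                return data_json["completion"]
            return ""
        except Exception:
            return ""

    print(parse_baseten_chunk(b'data: {"token": {"text": "Hello"}}'))  # -> Hello
    print(parse_baseten_chunk(b'{"completion": " world"}'))            # ->  world
    print(parse_baseten_chunk(b'{"model_output": "!"}'))               # -> !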
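Taken together, the patch lets a TGI-deployed Baseten model be called, and streamed, through litellm's usual interface. A usage sketch mirroring the commented-out tests above; "baseten/RqgAEn0" is the deployment id used in those tests, and BASETEN_API_KEY is assumed to be the credential litellm reads for Baseten.

    # Usage sketch, mirroring the commented-out tests in this patch.
    import os
    from litellm import completion

    os.environ["BASETEN_API_KEY"] = "<your-baseten-api-key>"  # placeholder
    messages = [{"content": "write me a function to print hello world in python", "role": "user"}]

    # Non-streaming: returns an OpenAI-style response object.
    response = completion(model="baseten/RqgAEn0", messages=messages)
    print(response)

    # Streaming: chunks expose delta content, as in the streaming test above.
    response = completion(model="baseten/RqgAEn0", messages=messages, stream=True)
    for chunk in response:
        print(chunk["choices"][0]["delta"]["content"], end="")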