diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 8de1d76380..0d73a670da 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index 402182564a..e6cb1a5594 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/llms/baseten.py b/litellm/llms/baseten.py
index 49753d67b3..b11ae179dd 100644
--- a/litellm/llms/baseten.py
+++ b/litellm/llms/baseten.py
@@ -60,9 +60,12 @@ class BasetenLLM:
             else:
                 prompt += f"{message['content']}"
         data = {
-            "prompt": prompt,
+            # "prompt": prompt,
+            "inputs": prompt,  # in case it's a TGI deployed model
             # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
-            **optional_params,
+            # **optional_params,
+            "parameters": optional_params,
+            "stream": True if "stream" in optional_params and optional_params["stream"] == True else False
         }
 
         ## LOGGING
@@ -76,8 +79,9 @@ class BasetenLLM:
             self.completion_url_fragment_1 + model + self.completion_url_fragment_2,
             headers=self.headers,
             data=json.dumps(data),
+            stream=True if "stream" in optional_params and optional_params["stream"] == True else False
         )
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if 'text/event-stream' in response.headers['Content-Type'] or ("stream" in optional_params and optional_params["stream"] == True):
             return response.iter_lines()
         else:
             ## LOGGING
@@ -117,9 +121,23 @@ class BasetenLLM:
                     model_response["choices"][0]["message"][
                         "content"
                     ] = completion_response["completion"]
+                elif isinstance(completion_response, list) and len(completion_response) > 0:
+                    if "generated_text" not in completion_response[0]:
+                        raise BasetenError(
+                            message=f"Unable to parse response. Original response: {response.text}",
+                            status_code=response.status_code
+                        )
+                    model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
+                    ## GETTING LOGPROBS
+                    if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                        sum_logprob = 0
+                        for token in completion_response[0]["details"]["tokens"]:
+                            sum_logprob += token["logprob"]
+                        model_response["choices"][0]["message"]["logprobs"] = sum_logprob
                 else:
-                    raise ValueError(
-                        f"Unable to parse response. Original response: {response.text}"
+                    raise BasetenError(
+                        message=f"Unable to parse response. Original response: {response.text}",
+                        status_code=response.status_code
                     )
 
             ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
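The list-shaped branch added above follows the Hugging Face TGI response format: a list whose first element carries `generated_text` and, optionally, per-token `details`. Below is a minimal standalone sketch of that parsing step, assuming a TGI-style payload; the field names come from the patch, while the example values are hypothetical.

# Sketch only (not part of the patch): parse a TGI-style Baseten response the way
# the new elif branch in baseten.py does. The payload below is a made-up example.
example_response = [
    {
        "generated_text": "def hello_world():\n    print('hello world')",
        "details": {
            "tokens": [
                {"text": "def", "logprob": -0.12},
                {"text": " hello_world", "logprob": -0.34},
            ]
        },
    }
]

def parse_tgi_completion(completion_response):
    first = completion_response[0]
    if "generated_text" not in first:
        raise ValueError("Unable to parse response")
    content = first["generated_text"]
    # Sum per-token logprobs when the deployment returns details, as the patch does.
    sum_logprob = None
    if "details" in first and "tokens" in first["details"]:
        sum_logprob = sum(token["logprob"] for token in first["details"]["tokens"])
    return content, sum_logprob

print(parse_tgi_completion(example_response))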
diff --git a/litellm/main.py b/litellm/main.py
index 285aed48f5..3b0bbc6450 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1,4 +1,4 @@
-import os, openai, sys, json
+import os, openai, sys, json, inspect
 from typing import Any
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
@@ -682,7 +682,7 @@ def completion(
             litellm_params=litellm_params,
             logger_fn=logger_fn,
         )
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if inspect.isgenerator(model_response) or ("stream" in optional_params and optional_params["stream"] == True):
             # don't try to access stream object,
             response = CustomStreamWrapper(
                 model_response, model, custom_llm_provider="baseten", logging_obj=logging
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index d890c54153..93f7fedc89 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -20,7 +20,7 @@ litellm.use_client = True
 # litellm.set_verbose = True
 # litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])
 
-user_message = "Hello, whats the weather in San Francisco??"
+user_message = "write me a function to print hello world in python"
 messages = [{"content": user_message, "role": "user"}]
 
 
@@ -383,6 +383,14 @@ def test_completion_with_fallbacks():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+# def test_baseten():
+#     try:
+
+#         response = completion(model="baseten/RqgAEn0", messages=messages, logger_fn=logger_fn)
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
 
 # def test_baseten_falcon_7bcompletion():
 #     model_name = "qvv0xeq"
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 306a317eb9..047fc45537 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -25,23 +25,24 @@ user_message = "Hello, how are you?"
 messages = [{"content": user_message, "role": "user"}]
 
 # test on baseten completion call
-try:
-    response = completion(
-        model="wizard-lm", messages=messages, stream=True, logger_fn=logger_fn
-    )
-    print(f"response: {response}")
-    complete_response = ""
-    start_time = time.time()
-    for chunk in response:
-        chunk_time = time.time()
-        print(f"time since initial request: {chunk_time - start_time:.5f}")
-        print(chunk["choices"][0]["delta"])
-        complete_response += chunk["choices"][0]["delta"]["content"]
-    if complete_response == "":
-        raise Exception("Empty response received")
-except:
-    print(f"error occurred: {traceback.format_exc()}")
-    pass
+# try:
+#     response = completion(
+#         model="baseten/RqgAEn0", messages=messages, logger_fn=logger_fn
+#     )
+#     print(f"response: {response}")
+#     complete_response = ""
+#     start_time = time.time()
+#     for chunk in response:
+#         chunk_time = time.time()
+#         print(f"time since initial request: {chunk_time - start_time:.5f}")
+#         print(chunk["choices"][0]["delta"])
+#         complete_response += chunk["choices"][0]["delta"]["content"]
+#     if complete_response == "":
+#         raise Exception("Empty response received")
+#     print(f"complete response: {complete_response}")
+# except:
+#     print(f"error occurred: {traceback.format_exc()}")
+#     pass
 
 # test on openai completion call
 try:
diff --git a/litellm/utils.py b/litellm/utils.py
index 61b0d56450..3c4a57ead6 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -735,7 +735,8 @@ def get_optional_params(  # use the openai defaults
     elif custom_llm_provider == "baseten":
         optional_params["temperature"] = temperature
         optional_params["stream"] = stream
-        optional_params["top_p"] = top_p
+        if top_p != 1:
+            optional_params["top_p"] = top_p
         optional_params["top_k"] = top_k
         optional_params["num_beams"] = num_beams
         if max_tokens != float("inf"):
@@ -1739,18 +1740,31 @@ class CustomStreamWrapper:
             return chunk["choices"][0]["delta"]["content"]
 
     def handle_baseten_chunk(self, chunk):
-        chunk = chunk.decode("utf-8")
-        data_json = json.loads(chunk)
-        if "model_output" in data_json:
-            if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
-                return data_json["model_output"]["data"][0]
-            elif isinstance(data_json["model_output"], str):
-                return data_json["model_output"]
-            elif "completion" in data_json and isinstance(data_json["completion"], str):
-                return data_json["completion"]
+        try:
+            chunk = chunk.decode("utf-8")
+            if len(chunk) > 0:
+                if chunk.startswith("data:"):
+                    data_json = json.loads(chunk[5:])
+                    if "token" in data_json and "text" in data_json["token"]:
+                        return data_json["token"]["text"]
+                    else:
+                        return ""
+                data_json = json.loads(chunk)
+                if "model_output" in data_json:
+                    if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
+                        return data_json["model_output"]["data"][0]
+                    elif isinstance(data_json["model_output"], str):
+                        return data_json["model_output"]
+                    elif "completion" in data_json and isinstance(data_json["completion"], str):
+                        return data_json["completion"]
+                    else:
+                        raise ValueError(f"Unable to parse response. Original response: {chunk}")
+                else:
+                    return ""
             else:
-                raise ValueError(f"Unable to parse response. Original response: {chunk}")
-        else:
+                return ""
+        except:
+            traceback.print_exc()
             return ""
 
     def __next__(self):
diff --git a/pyproject.toml b/pyproject.toml
index 4297dcc947..3f17c7723e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.508"
+version = "0.1.509"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
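For context, a hedged usage sketch of the streaming path this patch wires up. The `baseten/RqgAEn0` model id is copied from the commented-out test above and stands in for any Baseten-hosted TGI deployment; the API key value is a placeholder.

# Sketch only: stream a completion from a Baseten-hosted model through litellm.
# Assumes BASETEN_API_KEY holds a real key and the model id points at a live deployment.
import os
from litellm import completion

os.environ["BASETEN_API_KEY"] = "<your-baseten-api-key>"  # placeholder

messages = [{"content": "write me a function to print hello world in python", "role": "user"}]

# With stream=True, completion() returns a CustomStreamWrapper; each chunk's delta
# content is produced by handle_baseten_chunk parsing the "data:" SSE lines.
response = completion(model="baseten/RqgAEn0", messages=messages, stream=True)
complete_response = ""
for chunk in response:
    complete_response += chunk["choices"][0]["delta"]["content"]
print(complete_response)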