diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc index 61e0b648e..be35f5b19 100644 Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc index 4e033d0a4..0b1e34163 100644 Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ diff --git a/litellm/main.py b/litellm/main.py index 19c393e49..3ff22ecc3 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -540,16 +540,9 @@ def completion( ## LOGGING logging.pre_call(input=prompt, api_key=TOGETHER_AI_TOKEN) - if stream == True: - return together_ai_completion_streaming( - { - "model": model, - "prompt": prompt, - "request_type": "language-model-inference", - **optional_params, - }, - headers=headers, - ) + + print(f"TOGETHER_AI_TOKEN: {TOGETHER_AI_TOKEN}") + res = requests.post( endpoint, json={ @@ -560,6 +553,12 @@ def completion( }, headers=headers, ) + + if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True: + response = CustomStreamWrapper( + res.iter_lines(), model, custom_llm_provider="together_ai" + ) + return response ## LOGGING logging.post_call( input=prompt, api_key=TOGETHER_AI_TOKEN, original_response=res.text diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 25fa5c047..7b55c9869 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -9,13 +9,14 @@ sys.path.insert( ) # Adds the parent directory to the system path import litellm from litellm import completion -litellm.logging = True -litellm.set_verbose = True +litellm.logging = False +litellm.set_verbose = False score = 0 def logger_fn(model_call_object: dict): + return print(f"model call details: {model_call_object}") @@ -81,17 +82,91 @@ except: # # test on huggingface completion call # try: +# start_time = time.time() # response = completion( -# model="meta-llama/Llama-2-7b-chat-hf", -# messages=messages, -# custom_llm_provider="huggingface", -# custom_api_base="https://s7c7gytn18vnu4tw.us-east-1.aws.endpoints.huggingface.cloud", -# stream=True, -# logger_fn=logger_fn, +# model="gpt-3.5-turbo", messages=messages, stream=True, logger_fn=logger_fn # ) +# complete_response = "" # for chunk in response: +# chunk_time = time.time() +# print(f"time since initial request: {chunk_time - start_time:.2f}") # print(chunk["choices"][0]["delta"]) -# score += 1 +# complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else "" +# if complete_response == "": +# raise Exception("Empty response received") # except: # print(f"error occurred: {traceback.format_exc()}") # pass + +# test on together ai completion call +try: + start_time = time.time() + response = completion( + model="Replit-Code-3B", messages=messages, logger_fn=logger_fn, stream= True + ) + complete_response = "" + print(f"returned response object: {response}") + for chunk in response: + chunk_time = time.time() + print(f"time since initial request: {chunk_time - start_time:.2f}") + print(chunk["choices"][0]["delta"]) + complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else "" + if complete_response == "": + raise Exception("Empty response received") +except: + print(f"error occurred: {traceback.format_exc()}") + pass + + +# # test on azure completion call +# try: +# response 
= completion( +# model="azure/chatgpt-test", messages=messages, stream=True, logger_fn=logger_fn +# ) +# response = "" +# for chunk in response: +# chunk_time = time.time() +# print(f"time since initial request: {chunk_time - start_time:.2f}") +# print(chunk["choices"][0]["delta"]) +# response += chunk["choices"][0]["delta"] +# if response == "": +# raise Exception("Empty response received") +# except: +# print(f"error occurred: {traceback.format_exc()}") +# pass + + +# # test on anthropic completion call +# try: +# response = completion( +# model="claude-instant-1", messages=messages, stream=True, logger_fn=logger_fn +# ) +# response = "" +# for chunk in response: +# chunk_time = time.time() +# print(f"time since initial request: {chunk_time - start_time:.2f}") +# print(chunk["choices"][0]["delta"]) +# response += chunk["choices"][0]["delta"] +# if response == "": +# raise Exception("Empty response received") +# except: +# print(f"error occurred: {traceback.format_exc()}") +# pass + + +# # # test on huggingface completion call +# # try: +# # response = completion( +# # model="meta-llama/Llama-2-7b-chat-hf", +# # messages=messages, +# # custom_llm_provider="huggingface", +# # custom_api_base="https://s7c7gytn18vnu4tw.us-east-1.aws.endpoints.huggingface.cloud", +# # stream=True, +# # logger_fn=logger_fn, +# # ) +# # for chunk in response: +# # print(chunk["choices"][0]["delta"]) +# # score += 1 +# # except: +# # print(f"error occurred: {traceback.format_exc()}") +# # pass diff --git a/litellm/utils.py b/litellm/utils.py index 96d094088..348304dca 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -371,6 +371,8 @@ def client(original_function): ) if "logger_fn" in kwargs: user_logger_fn = kwargs["logger_fn"] + # LOG SUCCESS + crash_reporting(*args, **kwargs) except: # DO NOT BLOCK running the function because of this print_verbose(f"[Non-Blocking] {traceback.format_exc()}") pass @@ -444,26 +446,27 @@ def client(original_function): function_setup(*args, **kwargs) litellm_call_id = str(uuid.uuid4()) kwargs["litellm_call_id"] = litellm_call_id - # [OPTIONAL] CHECK CACHE start_time = datetime.datetime.now() + # [OPTIONAL] CHECK CACHE if (litellm.caching or litellm.caching_with_models) and ( cached_result := check_cache(*args, **kwargs)) is not None: result = cached_result - else: - # MODEL CALL - result = original_function(*args, **kwargs) + return result + # MODEL CALL + result = original_function(*args, **kwargs) + if "stream" in kwargs and kwargs["stream"] == True: + return result end_time = datetime.datetime.now() - # Add response to CACHE - if litellm.caching: + # [OPTIONAL] ADD TO CACHE + if (litellm.caching or litellm.caching_with_models): add_cache(result, *args, **kwargs) # LOG SUCCESS - crash_reporting(*args, **kwargs) - my_thread = threading.Thread( target=handle_success, args=(args, kwargs, result, start_time, end_time)) # don't interrupt execution of main thread my_thread.start() + # RETURN RESULT return result except Exception as e: @@ -1465,7 +1468,7 @@ class CustomStreamWrapper: if model in litellm.cohere_models: # cohere does not return an iterator, so we need to wrap it in one self.completion_stream = iter(completion_stream) - elif model == "together_ai": + elif custom_llm_provider == "together_ai": self.completion_stream = iter(completion_stream) else: self.completion_stream = completion_stream
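
Reviewer note: the new `custom_llm_provider == "together_ai"` branch in `CustomStreamWrapper` is fed `res.iter_lines()` directly, but the diff does not show how those raw lines become the `chunk["choices"][0]["delta"]["content"]` shape the new test consumes. The sketch below is illustrative only, not the library's actual parser; it assumes Together AI streams server-sent-event lines of the form `data: {...}` carrying the token under `choices[0]["text"]` and a `data: [DONE]` sentinel, none of which is confirmed by this diff.

```python
import json
from typing import Iterator, Optional


def parse_together_ai_stream(lines: Iterator[bytes], model: str) -> Iterator[dict]:
    """Illustrative sketch: convert raw streamed lines (as yielded by
    requests' res.iter_lines()) into OpenAI-style streaming chunks.

    Assumes each payload line looks like:
        b'data: {"choices": [{"text": "..."}], ...}'
    """
    for raw_line in lines:
        if not raw_line:
            # iter_lines() yields empty keep-alive lines; skip them
            continue
        line = raw_line.decode("utf-8").strip()
        if not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":  # assumed end-of-stream sentinel
            break
        data = json.loads(payload)
        token: Optional[str] = data.get("choices", [{}])[0].get("text")
        if token is None:
            continue
        # Mirror the shape the new test reads: chunk["choices"][0]["delta"]["content"]
        yield {
            "model": model,
            "choices": [{"delta": {"content": token}}],
        }
```

A caller would then accumulate `chunk["choices"][0]["delta"]["content"]` exactly as the added Together AI block in `test_streaming.py` does, raising if the assembled response is empty.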