diff --git a/litellm/main.py b/litellm/main.py
index 17144a47f0..b4a70709bf 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -7,32 +7,9 @@ import litellm
 from litellm import client, logging, exception_type, timeout, get_optional_params
 import tiktoken
 encoding = tiktoken.get_encoding("cl100k_base")
-from litellm.utils import get_secret, install_and_import
+from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
-
-# TODO this will evolve to accepting models
-# replicate/anthropic/cohere
-class CustomStreamWrapper:
-  def __init__(self, completion_stream, model):
-    self.model = model
-    if model in litellm.cohere_models:
-      # cohere does not return an iterator, so we need to wrap it in one
-      self.completion_stream = iter(completion_stream)
-    else:
-      self.completion_stream = completion_stream
-
-  def __iter__(self):
-    return self
-
-  def __next__(self):
-    if self.model in litellm.anthropic_models:
-      chunk = next(self.completion_stream)
-      return {"choices": [{"delta": chunk.completion}]}
-    elif self.model in litellm.cohere_models:
-      chunk = next(self.completion_stream)
-      return {"choices": [{"delta": chunk.text}]}
-
 new_response = {
         "choices": [
             {
@@ -67,7 +44,7 @@ def completion(
     presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
     # Optional liteLLM function params
     *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False,
-    hugging_face = False
+    hugging_face = False, replicate=False,
   ):
   try:
     global new_response
@@ -77,7 +54,8 @@ def completion(
       functions=functions, function_call=function_call,
       temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
      presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
-      model=model
+      # params to identify the model
+      model=model, replicate=replicate, hugging_face=hugging_face
     )
     if azure == True:
       # azure configs
@@ -172,7 +150,7 @@ def completion(
       model_response["model"] = model
       model_response["usage"] = response["usage"]
       response = model_response
-    elif "replicate" in model:
+    elif "replicate" in model or replicate == True:
       # import replicate/if it fails then pip install replicate
       install_and_import("replicate")
       import replicate
@@ -196,6 +174,11 @@ def completion(
       output = replicate.run(
         model,
         input=input)
+      if 'stream' in optional_params and optional_params['stream'] == True:
+        # don't try to access stream object,
+        # let the stream handler know this is replicate
+        response = CustomStreamWrapper(output, "replicate")
+        return response
       response = ""
       for item in output:
         response += item
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index d5733e2fb4..304eb0303e 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -139,20 +139,36 @@ def test_completion_azure():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
 
-
-# Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
-# [TODO] improve our try-except block to handle for these
-# def test_completion_replicate_llama():
-#     model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
-#     try:
-#         response = completion(model=model_name, messages=messages, max_tokens=500)
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         print(f"in replicate llama, got error {e}")
-#         pass
-#         if e == "FunctionTimedOut":
-#             pass
-#         else:
-#             pytest.fail(f"Error occurred: {e}")
+def test_completion_replicate_llama_stream():
+    model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
+    try:
+        response = completion(model=model_name, messages=messages, stream=True)
+        # Add any assertions here to check the response
+        for result in response:
+            print(result)
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+def test_completion_replicate_stability_stream():
+    model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
+    try:
+        response = completion(model=model_name, messages=messages, stream=True, replicate=True)
+        # Add any assertions here to check the response
+        for result in response:
+            print(result)
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+def test_completion_replicate_stability():
+    model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
+    try:
+        response = completion(model=model_name, messages=messages, replicate=True)
+        # Add any assertions here to check the response
+        for result in response:
+            print(result)
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 04e92737a5..c92440dce9 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -148,6 +148,8 @@ def get_optional_params(
     user = "",
     deployment_id = None,
     model = None,
+    replicate = False,
+    hugging_face = False,
 ):
   optional_params = {}
   if model in litellm.anthropic_models:
@@ -170,7 +172,12 @@ def get_optional_params(
     if max_tokens != float('inf'):
       optional_params["max_tokens"] = max_tokens
     return optional_params
-
+  elif replicate == True:
+    # any replicate models
+    # TODO: handle translating remaining replicate params
+    if stream:
+      optional_params["stream"] = stream
+    return optional_params
   else:# assume passing in params for openai/azure openai
     if functions != []:
       optional_params["functions"] = functions
@@ -199,6 +206,7 @@ def get_optional_params(
     if deployment_id != None:
       optional_params["deployment_id"] = deployment_id
     return optional_params
+  return optional_params
 
 def set_callbacks(callback_list):
   global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient
@@ -557,3 +565,30 @@ def get_secret(secret_name):
     return os.environ.get(secret_name)
   else:
     return os.environ.get(secret_name)
+
+######## Streaming Class ############################
+# wraps the completion stream to return the correct format for the model
+# replicate/anthropic/cohere
+class CustomStreamWrapper:
+  def __init__(self, completion_stream, model):
+    self.model = model
+    if model in litellm.cohere_models:
+      # cohere does not return an iterator, so we need to wrap it in one
+      self.completion_stream = iter(completion_stream)
+    else:
+      self.completion_stream = completion_stream
+
+  def __iter__(self):
+    return self
+
+  def __next__(self):
+    if self.model in litellm.anthropic_models:
+      chunk = next(self.completion_stream)
+      return {"choices": [{"delta": chunk.completion}]}
+    elif self.model == "replicate":
+      chunk = next(self.completion_stream)
+      return {"choices": [{"delta": chunk}]}
+    elif self.model in litellm.cohere_models:
+      chunk = next(self.completion_stream)
+      return {"choices": [{"delta": chunk.text}]}
+
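
Usage sketch (illustrative, not part of the patch): with this change, calling completion() on a Replicate model with stream=True returns a CustomStreamWrapper that yields OpenAI-style delta chunks instead of a joined string. The snippet below mirrors test_completion_replicate_llama_stream above; the messages list is an assumed example (the tests use a module-level `messages` not shown in this diff), the model version string is copied from the tests, and Replicate credentials are assumed to be configured in the environment.

    from litellm import completion

    # assumed example input; any OpenAI-style chat messages list works
    messages = [{"role": "user", "content": "Hello, how are you?"}]

    # "replicate" in the model string (or passing replicate=True) routes the call to
    # the Replicate branch; with stream=True, completion() now returns a
    # CustomStreamWrapper instead of the concatenated output string.
    response = completion(
        model="replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1",
        messages=messages,
        stream=True,
    )

    for chunk in response:
        # each chunk is shaped like {"choices": [{"delta": <text piece>}]}
        print(chunk["choices"][0]["delta"], end="")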