From 6ea264d561a80b722ca08973edfba85424aaf872 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Tue, 8 Aug 2023 15:57:24 -0700
Subject: [PATCH] streaming for anthropic

---
 litellm/main.py                  | 24 +++++++++++-
 litellm/tests/test_completion.py | 13 +++++++
 litellm/utils.py                 | 67 +++++++++++++++++++-------------
 3 files changed, 75 insertions(+), 29 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index 8f7873099..8d8c78e25 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -11,6 +11,19 @@ from litellm.utils import get_secret, install_and_import
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
 
+# TODO this will evolve to accepting models
+# replicate/anthropic/cohere
+class CustomStreamWrapper:
+  def __init__(self, completion_stream):
+    self.completion_stream = completion_stream
+
+  def __iter__(self):
+    return self
+
+  def __next__(self):
+    chunk = next(self.completion_stream)
+    return {"choices": [{"delta": chunk.completion}]}
+
 new_response = {
     "choices": [
       {
@@ -54,7 +67,8 @@ def completion(
   optional_params = get_optional_params(
     functions=functions, function_call=function_call,
     temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
-    presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id
+    presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
+    model=model
   )
   if azure == True:
     # azure configs
@@ -222,8 +236,14 @@ def completion(
       completion = anthropic.completions.create(
         model=model,
         prompt=prompt,
-        max_tokens_to_sample=max_tokens_to_sample
+        max_tokens_to_sample=max_tokens_to_sample,
+        **optional_params
       )
+      if optional_params['stream'] == True:
+        # don't try to access stream object,
+        response = CustomStreamWrapper(completion)
+        return response
+
       completion_response = completion.completion
       ## LOGGING
       logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index e001daa61..35f7b631d 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -26,6 +26,19 @@ def test_completion_claude():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+def test_completion_claude_stream():
+    try:
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "how does a court case get to the Supreme Court?"}
+        ]
+        response = completion(model="claude-2", messages=messages, stream=True)
+        # Add any assertions here to check the response
+        for chunk in response:
+            print(chunk['choices'][0]['delta']) # same as openai format
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
 def test_completion_hf_api():
     try:
         user_message = "write some code to find the sum of two numbers"
diff --git a/litellm/utils.py b/litellm/utils.py
index 5b4820131..599c61e24 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -146,36 +146,49 @@ def get_optional_params(
     frequency_penalty = 0,
     logit_bias = {},
     user = "",
-    deployment_id = None
+    deployment_id = None,
+    model = None,
 ):
   optional_params = {}
-  if functions != []:
-    optional_params["functions"] = functions
-  if function_call != "":
-    optional_params["function_call"] = function_call
-  if temperature != 1:
-    optional_params["temperature"] = temperature
-  if top_p != 1:
-    optional_params["top_p"] = top_p
-  if n != 1:
-    optional_params["n"] = n
-  if stream:
+  if model in litellm.anthropic_models:
+    # handle anthropic params
+    if stream:
       optional_params["stream"] = stream
-  if stop != None:
-    optional_params["stop"] = stop
-  if max_tokens != float('inf'):
-    optional_params["max_tokens"] = max_tokens
-  if presence_penalty != 0:
-    optional_params["presence_penalty"] = presence_penalty
-  if frequency_penalty != 0:
-    optional_params["frequency_penalty"] = frequency_penalty
-  if logit_bias != {}:
-    optional_params["logit_bias"] = logit_bias
-  if user != "":
-    optional_params["user"] = user
-  if deployment_id != None:
-    optional_params["deployment_id"] = deployment_id
-  return optional_params
+    if stop != None:
+      optional_params["stop_sequences"] = stop
+    if temperature != 1:
+      optional_params["temperature"] = temperature
+    if top_p != 1:
+      optional_params["top_p"] = top_p
+    return optional_params
+  else:# assume passing in params for openai/azure openai
+    if functions != []:
+      optional_params["functions"] = functions
+    if function_call != "":
+      optional_params["function_call"] = function_call
+    if temperature != 1:
+      optional_params["temperature"] = temperature
+    if top_p != 1:
+      optional_params["top_p"] = top_p
+    if n != 1:
+      optional_params["n"] = n
+    if stream:
+      optional_params["stream"] = stream
+    if stop != None:
+      optional_params["stop"] = stop
+    if max_tokens != float('inf'):
+      optional_params["max_tokens"] = max_tokens
+    if presence_penalty != 0:
+      optional_params["presence_penalty"] = presence_penalty
+    if frequency_penalty != 0:
+      optional_params["frequency_penalty"] = frequency_penalty
+    if logit_bias != {}:
+      optional_params["logit_bias"] = logit_bias
+    if user != "":
+      optional_params["user"] = user
+    if deployment_id != None:
+      optional_params["deployment_id"] = deployment_id
+    return optional_params
 
 def set_callbacks(callback_list):
   global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient
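
Usage note: a minimal sketch of how the streaming path introduced by this patch is expected to be consumed, mirroring the new test_completion_claude_stream above. It assumes litellm at this revision is installed and ANTHROPIC_API_KEY is set in the environment.

    from litellm import completion

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "how does a court case get to the Supreme Court?"}
    ]

    # With stream=True and a claude model, completion() returns a CustomStreamWrapper
    # that yields chunks of the form {"choices": [{"delta": <text fragment>}]}.
    response = completion(model="claude-2", messages=messages, stream=True)

    for chunk in response:
        print(chunk["choices"][0]["delta"])  # each delta is the next piece of generated text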