streaming for anthropic

ishaan-jaff 2023-08-08 15:57:24 -07:00
parent e4f96075c3
commit 36a6ac9b08
3 changed files with 75 additions and 29 deletions

View file

@@ -11,6 +11,19 @@ from litellm.utils import get_secret, install_and_import
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
+# TODO this will evolve to accepting models
+# replicate/anthropic/cohere
+class CustomStreamWrapper:
+    def __init__(self, completion_stream):
+        self.completion_stream = completion_stream
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        chunk = next(self.completion_stream)
+        return {"choices": [{"delta": chunk.completion}]}
+
 new_response = {
     "choices": [
         {
@@ -54,7 +67,8 @@ def completion(
     optional_params = get_optional_params(
         functions=functions, function_call=function_call,
         temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
-        presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id
+        presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
+        model=model
     )
     if azure == True:
         # azure configs
@@ -222,8 +236,14 @@ def completion(
         completion = anthropic.completions.create(
             model=model,
             prompt=prompt,
-            max_tokens_to_sample=max_tokens_to_sample
+            max_tokens_to_sample=max_tokens_to_sample,
+            **optional_params
         )
+        if optional_params['stream'] == True:
+            # don't try to access stream object,
+            response = CustomStreamWrapper(completion)
+            return response
         completion_response = completion.completion
         ## LOGGING
         logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
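
A minimal, self-contained sketch of how the new wrapper behaves. The Anthropic SDK's streamed events expose the new text on a .completion attribute (which the diff relies on); here a SimpleNamespace object stands in for those events so the snippet runs without any API call. The stand-in stream is invented for illustration and is not part of the SDK.

from types import SimpleNamespace

class CustomStreamWrapper:  # copied from the diff above
    def __init__(self, completion_stream):
        self.completion_stream = completion_stream

    def __iter__(self):
        return self

    def __next__(self):
        chunk = next(self.completion_stream)
        return {"choices": [{"delta": chunk.completion}]}

# Hypothetical stand-in for anthropic.completions.create(..., stream=True)
fake_stream = iter([SimpleNamespace(completion=piece) for piece in ["Hello", ", ", "world"]])
for chunk in CustomStreamWrapper(fake_stream):
    print(chunk["choices"][0]["delta"])  # prints "Hello", ", ", "world"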

View file

@@ -26,6 +26,19 @@ def test_completion_claude():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+def test_completion_claude_stream():
+    try:
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "how does a court case get to the Supreme Court?"}
+        ]
+        response = completion(model="claude-2", messages=messages, stream=True)
+        # Add any assertions here to check the response
+        for chunk in response:
+            print(chunk['choices'][0]['delta']) # same as openai format
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
 def test_completion_hf_api():
     try:
         user_message = "write some code to find the sum of two numbers"
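
A hedged usage sketch building on the test above: in this commit each chunk's delta is plain text (not an OpenAI-style dict with a content key), so a caller can rebuild the full reply by concatenating the deltas. This assumes completion is imported from litellm, as in the test module.

from litellm import completion

messages = [{"role": "user", "content": "how does a court case get to the Supreme Court?"}]
response = completion(model="claude-2", messages=messages, stream=True)

full_text = ""
for chunk in response:
    full_text += chunk["choices"][0]["delta"]  # delta is plain text in this commit
print(full_text)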

View file

@@ -146,36 +146,49 @@ def get_optional_params(
     frequency_penalty = 0,
     logit_bias = {},
     user = "",
-    deployment_id = None
+    deployment_id = None,
+    model = None,
 ):
     optional_params = {}
-    if functions != []:
-        optional_params["functions"] = functions
-    if function_call != "":
-        optional_params["function_call"] = function_call
-    if temperature != 1:
-        optional_params["temperature"] = temperature
-    if top_p != 1:
-        optional_params["top_p"] = top_p
-    if n != 1:
-        optional_params["n"] = n
-    if stream:
-        optional_params["stream"] = stream
-    if stop != None:
-        optional_params["stop"] = stop
-    if max_tokens != float('inf'):
-        optional_params["max_tokens"] = max_tokens
-    if presence_penalty != 0:
-        optional_params["presence_penalty"] = presence_penalty
-    if frequency_penalty != 0:
-        optional_params["frequency_penalty"] = frequency_penalty
-    if logit_bias != {}:
-        optional_params["logit_bias"] = logit_bias
-    if user != "":
-        optional_params["user"] = user
-    if deployment_id != None:
-        optional_params["deployment_id"] = deployment_id
-    return optional_params
+    if model in litellm.anthropic_models:
+        # handle anthropic params
+        if stream:
+            optional_params["stream"] = stream
+        if stop != None:
+            optional_params["stop_sequences"] = stop
+        if temperature != 1:
+            optional_params["temperature"] = temperature
+        if top_p != 1:
+            optional_params["top_p"] = top_p
+        return optional_params
+    else:# assume passing in params for openai/azure openai
+        if functions != []:
+            optional_params["functions"] = functions
+        if function_call != "":
+            optional_params["function_call"] = function_call
+        if temperature != 1:
+            optional_params["temperature"] = temperature
+        if top_p != 1:
+            optional_params["top_p"] = top_p
+        if n != 1:
+            optional_params["n"] = n
+        if stream:
+            optional_params["stream"] = stream
+        if stop != None:
+            optional_params["stop"] = stop
+        if max_tokens != float('inf'):
+            optional_params["max_tokens"] = max_tokens
+        if presence_penalty != 0:
+            optional_params["presence_penalty"] = presence_penalty
+        if frequency_penalty != 0:
+            optional_params["frequency_penalty"] = frequency_penalty
+        if logit_bias != {}:
+            optional_params["logit_bias"] = logit_bias
+        if user != "":
+            optional_params["user"] = user
+        if deployment_id != None:
+            optional_params["deployment_id"] = deployment_id
+        return optional_params
 
 def set_callbacks(callback_list):
     global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient
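
A small illustrative stand-in, not the library function itself, that mirrors the Anthropic branch added above: the OpenAI-style stop parameter is forwarded under Anthropic's stop_sequences name, and only parameters the Anthropic completions API accepts are kept. The function name and its reduced signature are hypothetical; parameter defaults follow the diff.

def map_anthropic_params(stream=False, stop=None, temperature=1, top_p=1):
    # mirrors the `if model in litellm.anthropic_models` branch of get_optional_params
    optional_params = {}
    if stream:
        optional_params["stream"] = stream
    if stop is not None:
        optional_params["stop_sequences"] = stop  # Anthropic's name for stop sequences
    if temperature != 1:
        optional_params["temperature"] = temperature
    if top_p != 1:
        optional_params["top_p"] = top_p
    return optional_params

print(map_anthropic_params(stream=True, stop=["\n\nHuman:"], temperature=0.5))
# {'stream': True, 'stop_sequences': ['\n\nHuman:'], 'temperature': 0.5}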