Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
streaming for anthropic

commit 6ea264d561 (parent 4a854eb17f)
3 changed files with 75 additions and 29 deletions
@@ -11,6 +11,19 @@ from litellm.utils import get_secret, install_and_import
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
 
+# TODO this will evolve to accepting models
+# replicate/anthropic/cohere
+class CustomStreamWrapper:
+    def __init__(self, completion_stream):
+        self.completion_stream = completion_stream
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        chunk = next(self.completion_stream)
+        return {"choices": [{"delta": chunk.completion}]}
+
 new_response = {
     "choices": [
         {
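
For readers skimming the diff: CustomStreamWrapper adapts Anthropic stream chunks, which carry incremental text on a .completion attribute, into the OpenAI-style {"choices": [{"delta": ...}]} shape. A minimal, self-contained sketch of that behavior, with FakeChunk standing in for the SDK's event type (not the real SDK class):

# Sketch only: FakeChunk stands in for the Anthropic SDK's stream event,
# which exposes incremental text on a `.completion` attribute.
from dataclasses import dataclass

@dataclass
class FakeChunk:
    completion: str

class CustomStreamWrapper:
    def __init__(self, completion_stream):
        self.completion_stream = completion_stream

    def __iter__(self):
        return self

    def __next__(self):
        # pull the next provider chunk and re-shape it into the
        # OpenAI-style delta dict that callers already know how to read
        chunk = next(self.completion_stream)
        return {"choices": [{"delta": chunk.completion}]}

stream = iter([FakeChunk("Hello"), FakeChunk(", world")])
for piece in CustomStreamWrapper(stream):
    print(piece["choices"][0]["delta"])  # prints "Hello", then ", world"

Note that "delta" here holds the raw text string, whereas OpenAI's streaming deltas are objects with a content field; the test added below relies only on the dict shape.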
@@ -54,7 +67,8 @@ def completion(
     optional_params = get_optional_params(
         functions=functions, function_call=function_call,
         temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
-        presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id
+        presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
+        model=model
     )
     if azure == True:
         # azure configs
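
The only change here is threading model into get_optional_params, so the helper can emit provider-specific parameter names (see the utils.py hunk below). A hypothetical invocation mirroring the new call site, with the expected result derived from that utils.py logic (values illustrative, and "claude-2" assumed to be listed in litellm.anthropic_models):

# Hypothetical call mirroring the new call site; the expected dict follows
# the anthropic branch added in utils.py below.
optional_params = get_optional_params(
    functions=[], function_call="",
    temperature=0.7, top_p=1, n=1, stream=True, stop=None, max_tokens=256,
    presence_penalty=0, frequency_penalty=0, logit_bias={}, user="",
    deployment_id=None,
    model="claude-2",  # assumed member of litellm.anthropic_models
)
# temperature differs from its default of 1, so it is kept; top_p does not.
# max_tokens is not mapped by the anthropic branch shown below.
# -> {"stream": True, "temperature": 0.7}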
@@ -222,8 +236,14 @@ def completion(
         completion = anthropic.completions.create(
             model=model,
             prompt=prompt,
-            max_tokens_to_sample=max_tokens_to_sample
+            max_tokens_to_sample=max_tokens_to_sample,
+            **optional_params
         )
+        if optional_params['stream'] == True:
+            # don't try to access the stream object; return the wrapper instead
+            response = CustomStreamWrapper(completion)
+            return response
+
         completion_response = completion.completion
         ## LOGGING
         logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
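
Why the early return matters: with stream=True, the SDK hands back an iterator of events rather than a finished response, so the completion.completion access below it would fail. A stubbed illustration of the two shapes involved (fake_create and FinishedResponse are stand-ins, not the real SDK):

# Stub illustrating the two response shapes the code above has to handle.
class FinishedResponse:
    completion = "full text"  # only the non-streaming response carries this

def fake_create(stream=False):
    if stream:
        return iter([])  # a bare iterator: no .completion attribute
    return FinishedResponse()

completion = fake_create(stream=True)
print(hasattr(completion, "completion"))  # False: reading it would raise
completion = fake_create(stream=False)
print(completion.completion)              # "full text"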
@@ -26,6 +26,19 @@ def test_completion_claude():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+def test_completion_claude_stream():
+    try:
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "how does a court case get to the Supreme Court?"}
+        ]
+        response = completion(model="claude-2", messages=messages, stream=True)
+        # Add any assertions here to check the response
+        for chunk in response:
+            print(chunk['choices'][0]['delta']) # same as openai format
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
 def test_completion_hf_api():
     try:
         user_message = "write some code to find the sum of two numbers"
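
The new test only prints chunks. A slightly stricter variant (hypothetical, not part of this commit) would also assert that some text actually arrived:

# Hypothetical follow-up test: collect the streamed deltas and assert that
# the model produced at least some text.
def test_completion_claude_stream_returns_text():
    try:
        messages = [
            {"role": "user", "content": "how does a court case get to the Supreme Court?"}
        ]
        response = completion(model="claude-2", messages=messages, stream=True)
        text = "".join(chunk['choices'][0]['delta'] for chunk in response)
        assert len(text) > 0
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")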
@@ -146,9 +146,22 @@ def get_optional_params(
     frequency_penalty = 0,
     logit_bias = {},
     user = "",
-    deployment_id = None
+    deployment_id = None,
+    model = None,
 ):
     optional_params = {}
+    if model in litellm.anthropic_models:
+        # handle anthropic params
+        if stream:
+            optional_params["stream"] = stream
+        if stop != None:
+            optional_params["stop_sequences"] = stop
+        if temperature != 1:
+            optional_params["temperature"] = temperature
+        if top_p != 1:
+            optional_params["top_p"] = top_p
+        return optional_params
+    else: # assume passing in params for openai/azure openai
+        if functions != []:
+            optional_params["functions"] = functions
+        if function_call != "":
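
A hedged unit-test sketch for the new branch. The test name and expected dict are mine, derived from the logic above; it assumes "claude-2" appears in litellm.anthropic_models and that the remaining keyword arguments default as in the signature:

# Hypothetical test for the anthropic branch of get_optional_params:
# `stop` is renamed to Anthropic's `stop_sequences`, and parameters left at
# their defaults (temperature=1, top_p=1) are omitted entirely.
def test_get_optional_params_anthropic():
    params = get_optional_params(stream=True, stop=["\nHuman:"], model="claude-2")
    assert params == {"stream": True, "stop_sequences": ["\nHuman:"]}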