(feat) maintain anthropic text completion

This commit is contained in:
ishaan-jaff 2024-03-04 11:16:34 -08:00
parent 9094be7fbd
commit 1183e5f2e5
4 changed files with 59 additions and 25 deletions

View file

@ -39,6 +39,7 @@ from litellm.utils import (
)
from .llms import (
anthropic,
anthropic_text,
together_ai,
ai21,
sagemaker,
@ -1018,28 +1019,55 @@ def completion(
or litellm.api_key
or os.environ.get("ANTHROPIC_API_KEY")
)
api_base = (
api_base
or litellm.api_base
or get_secret("ANTHROPIC_API_BASE")
or "https://api.anthropic.com/v1/messages"
)
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = anthropic.completion(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding, # for calculating input/output tokens
api_key=api_key,
logging_obj=logging,
headers=headers,
)
if (model == "claude-2") or (model == "claude-instant-1"):
# call anthropic /completion, only use this route for claude-2, claude-instant-1
api_base = (
api_base
or litellm.api_base
or get_secret("ANTHROPIC_API_BASE")
or "https://api.anthropic.com/v1/complete"
)
response = anthropic_text.completion(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding, # for calculating input/output tokens
api_key=api_key,
logging_obj=logging,
headers=headers,
)
else:
# call /messages
# default route for all anthropic models
api_base = (
api_base
or litellm.api_base
or get_secret("ANTHROPIC_API_BASE")
or "https://api.anthropic.com/v1/messages"
)
response = anthropic.completion(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding, # for calculating input/output tokens
api_key=api_key,
logging_obj=logging,
headers=headers,
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
response = CustomStreamWrapper(