remove init AnthropicClient for completion calls

ishaan-jaff 2023-09-04 09:34:15 -07:00
parent 2dc1c35a05
commit bc065f08df
2 changed files with 119 additions and 136 deletions

litellm/llms/anthropic.py

@@ -1,16 +1,15 @@
-import os, json
+import os
+import json
 from enum import Enum
 import requests
 import time
 from typing import Callable
 from litellm.utils import ModelResponse
 
 class AnthropicConstants(Enum):
     HUMAN_PROMPT = "\n\nHuman:"
     AI_PROMPT = "\n\nAssistant:"
 
 class AnthropicError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -19,132 +18,120 @@ class AnthropicError(Exception):
             self.message
         )  # Call the base class constructor with the parameters it needs
 
-class AnthropicLLM:
-    def __init__(
-        self, encoding, default_max_tokens_to_sample, logging_obj, api_key=None
-    ):
-        self.encoding = encoding
-        self.default_max_tokens_to_sample = default_max_tokens_to_sample
-        self.completion_url = "https://api.anthropic.com/v1/complete"
-        self.api_key = api_key
-        self.logging_obj = logging_obj
-        self.validate_environment(api_key=api_key)
-
-    def validate_environment(
-        self, api_key
-    ):  # set up the environment required to run the model
-        # set the api key
-        if self.api_key == None:
-            raise ValueError(
-                "Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
-            )
-        self.api_key = api_key
-        self.headers = {
-            "accept": "application/json",
-            "anthropic-version": "2023-06-01",
-            "content-type": "application/json",
-            "x-api-key": self.api_key,
-        }
-
-    def completion(
-        self,
-        model: str,
-        messages: list,
-        model_response: ModelResponse,
-        print_verbose: Callable,
-        optional_params=None,
-        litellm_params=None,
-        logger_fn=None,
-    ):  # logic for parsing in - calling - parsing out model completion calls
-        model = model
-        prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}"
-        for message in messages:
-            if "role" in message:
-                if message["role"] == "user":
-                    prompt += (
-                        f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
-                    )
-                else:
-                    prompt += (
-                        f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
-                    )
-            else:
-                prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
-        prompt += f"{AnthropicConstants.AI_PROMPT.value}"
-        if "max_tokens" in optional_params and optional_params["max_tokens"] != float(
-            "inf"
-        ):
-            max_tokens = optional_params["max_tokens"]
-        else:
-            max_tokens = self.default_max_tokens_to_sample
-        data = {
-            "model": model,
-            "prompt": prompt,
-            "max_tokens_to_sample": max_tokens,
-            **optional_params,
-        }
-        ## LOGGING
-        self.logging_obj.pre_call(
-            input=prompt,
-            api_key=self.api_key,
-            additional_args={"complete_input_dict": data},
-        )
-        ## COMPLETION CALL
-        if "stream" in optional_params and optional_params["stream"] == True:
-            response = requests.post(
-                self.completion_url,
-                headers=self.headers,
-                data=json.dumps(data),
-                stream=optional_params["stream"],
-            )
-            return response.iter_lines()
-        else:
-            response = requests.post(
-                self.completion_url, headers=self.headers, data=json.dumps(data)
-            )
-            ## LOGGING
-            self.logging_obj.post_call(
-                input=prompt,
-                api_key=self.api_key,
-                original_response=response.text,
-                additional_args={"complete_input_dict": data},
-            )
-            print_verbose(f"raw model_response: {response.text}")
-            ## RESPONSE OBJECT
-            try:
-                completion_response = response.json()
-            except:
-                raise AnthropicError(message=response.text, status_code=response.status_code)
-            if "error" in completion_response:
-                raise AnthropicError(
-                    message=str(completion_response["error"]),
-                    status_code=response.status_code,
-                )
-            else:
-                model_response["choices"][0]["message"][
-                    "content"
-                ] = completion_response["completion"]
-            ## CALCULATING USAGE
-            prompt_tokens = len(
-                self.encoding.encode(prompt)
-            )  ##[TODO] use the anthropic tokenizer here
-            completion_tokens = len(
-                self.encoding.encode(model_response["choices"][0]["message"]["content"])
-            )  ##[TODO] use the anthropic tokenizer here
-            model_response["created"] = time.time()
-            model_response["model"] = model
-            model_response["usage"] = {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": prompt_tokens + completion_tokens,
-            }
-            return model_response
-
-    def embedding(
-        self,
-    ):  # logic for parsing in - calling - parsing out model embedding calls
-        pass
+# makes headers for API call
+def validate_environment(api_key):
+    if api_key is None:
+        raise ValueError(
+            "Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
+        )
+    headers = {
+        "accept": "application/json",
+        "anthropic-version": "2023-06-01",
+        "content-type": "application/json",
+        "x-api-key": api_key,
+    }
+    return headers
+
+def completion(
+    model: str,
+    messages: list,
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    api_key,
+    logging_obj,
+    optional_params=None,
+    litellm_params=None,
+    logger_fn=None,
+):
+    headers = validate_environment(api_key)
+    prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}"
+    for message in messages:
+        if "role" in message:
+            if message["role"] == "user":
+                prompt += (
+                    f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
+                )
+            else:
+                prompt += (
+                    f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
+                )
+        else:
+            prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
+    prompt += f"{AnthropicConstants.AI_PROMPT.value}"
+    if "max_tokens" in optional_params and optional_params["max_tokens"] != float("inf"):
+        max_tokens = optional_params["max_tokens"]
+    else:
+        max_tokens = 256  # required anthropic param, default to 256 if user does not provide an input
+    data = {
+        "model": model,
+        "prompt": prompt,
+        "max_tokens_to_sample": max_tokens,
+        **optional_params,
+    }
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=prompt,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
+    ## COMPLETION CALL
+    if "stream" in optional_params and optional_params["stream"] == True:
+        response = requests.post(
+            "https://api.anthropic.com/v1/complete",
+            headers=headers,
+            data=json.dumps(data),
+            stream=optional_params["stream"],
+        )
+        return response.iter_lines()
+    else:
+        response = requests.post(
+            "https://api.anthropic.com/v1/complete", headers=headers, data=json.dumps(data)
+        )
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        try:
+            completion_response = response.json()
+        except:
+            raise AnthropicError(
+                message=response.text, status_code=response.status_code
+            )
+        if "error" in completion_response:
+            raise AnthropicError(
+                message=str(completion_response["error"]),
+                status_code=response.status_code,
+            )
+        else:
+            model_response["choices"][0]["message"]["content"] = completion_response[
+                "completion"
+            ]
+        ## CALCULATING USAGE
+        prompt_tokens = len(
+            encoding.encode(prompt)
+        )  ##[TODO] use the anthropic tokenizer here
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"]["content"])
+        )  ##[TODO] use the anthropic tokenizer here
+
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        model_response["usage"] = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+        return model_response
+
+def embedding():
+    # logic for parsing in - calling - parsing out model embedding calls
+    pass
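
For context, the message-formatting loop kept in the new module-level completion() flattens an OpenAI-style messages list into a single Human/Assistant transcript before the request is POSTed. A standalone sketch of just that loop (the sample messages are illustrative; the request and logging plumbing is omitted):

from enum import Enum

class AnthropicConstants(Enum):
    HUMAN_PROMPT = "\n\nHuman:"
    AI_PROMPT = "\n\nAssistant:"

messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What is 2+2?"},
]

# same logic as the loop in completion(): user turns get the Human marker,
# all other turns get the Assistant marker, and a trailing Assistant marker
# cues the model to respond
prompt = AnthropicConstants.HUMAN_PROMPT.value
for message in messages:
    if "role" in message:
        if message["role"] == "user":
            prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
        else:
            prompt += f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
    else:
        prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
prompt += AnthropicConstants.AI_PROMPT.value

print(repr(prompt))
# '\n\nHuman:\n\nHuman:Hi\n\nAssistant:Hello!\n\nHuman:What is 2+2?\n\nAssistant:'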

litellm/main.py

@@ -19,7 +19,7 @@ from litellm.utils import (
     read_config_args,
     completion_with_fallbacks,
 )
-from .llms.anthropic import AnthropicLLM
+from .llms import anthropic
 from .llms.huggingface_restapi import HuggingfaceRestAPILLM
 from .llms.baseten import BasetenLLM
 from .llms.ai21 import AI21LLM
@@ -61,7 +61,6 @@ async def acompletion(*args, **kwargs):
 @client
-# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2), reraise=True, retry_error_callback=lambda retry_state: setattr(retry_state.outcome, 'retry_variable', litellm.retry))  # retry call, turn this off by setting `litellm.retry = False`
 @timeout(  # type: ignore
     600
 )  ## set timeouts, in case calls hang (e.g. Azure) - default is 600s, override with `force_timeout`
@@ -79,7 +78,6 @@ def completion(
     max_tokens=float("inf"),
     presence_penalty=0,
     frequency_penalty=0,
-    num_beams=1,
     logit_bias={},
     user="",
     deployment_id=None,
@@ -89,6 +87,7 @@ def completion(
     api_key=None,
     api_version=None,
     force_timeout=600,
+    num_beams=1,
     logger_fn=None,
     verbose=False,
     azure=False,
@@ -407,13 +406,7 @@ def completion(
         anthropic_key = (
             api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY")
         )
-        anthropic_client = AnthropicLLM(
-            encoding=encoding,
-            default_max_tokens_to_sample=litellm.max_tokens,
-            api_key=anthropic_key,
-            logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit anthropic's requirements
-        )
-        model_response = anthropic_client.completion(
+        model_response = anthropic.completion(
             model=model,
             messages=messages,
             model_response=model_response,
@@ -421,6 +414,9 @@ def completion(
             optional_params=optional_params,
             litellm_params=litellm_params,
             logger_fn=logger_fn,
+            encoding=encoding,  # for calculating input/output tokens
+            api_key=anthropic_key,
+            logging_obj=logging,
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
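
Caller-facing behavior is unchanged by this refactor: litellm.completion() still routes Anthropic models through the same prompt/logging flow, just via the module-level function instead of an AnthropicLLM instance. A minimal usage sketch (the model name and placeholder key are illustrative, not taken from this commit):

import os
import litellm

os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."  # placeholder; a real key is required

# this call is dispatched to the module-level anthropic.completion() shown above
response = litellm.completion(
    model="claude-instant-1",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response["choices"][0]["message"]["content"])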