diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index 904c81819..021ec4a73 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -1,16 +1,15 @@
-import os, json
+import os
+import json
 from enum import Enum
 import requests
 import time
 from typing import Callable
 from litellm.utils import ModelResponse
-
 
 class AnthropicConstants(Enum):
     HUMAN_PROMPT = "\n\nHuman:"
     AI_PROMPT = "\n\nAssistant:"
-
 
 class AnthropicError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -19,132 +18,120 @@ class AnthropicError(Exception):
             self.message
         )  # Call the base class constructor with the parameters it needs
 
-
-class AnthropicLLM:
-    def __init__(
-        self, encoding, default_max_tokens_to_sample, logging_obj, api_key=None
-    ):
-        self.encoding = encoding
-        self.default_max_tokens_to_sample = default_max_tokens_to_sample
-        self.completion_url = "https://api.anthropic.com/v1/complete"
-        self.api_key = api_key
-        self.logging_obj = logging_obj
-        self.validate_environment(api_key=api_key)
-
-    def validate_environment(
-        self, api_key
-    ):  # set up the environment required to run the model
-        # set the api key
-        if self.api_key == None:
-            raise ValueError(
-                "Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
-            )
-        self.api_key = api_key
-        self.headers = {
-            "accept": "application/json",
-            "anthropic-version": "2023-06-01",
-            "content-type": "application/json",
-            "x-api-key": self.api_key,
-        }
-
-    def completion(
-        self,
-        model: str,
-        messages: list,
-        model_response: ModelResponse,
-        print_verbose: Callable,
-        optional_params=None,
-        litellm_params=None,
-        logger_fn=None,
-    ):  # logic for parsing in - calling - parsing out model completion calls
-        model = model
-        prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}"
-        for message in messages:
-            if "role" in message:
-                if message["role"] == "user":
-                    prompt += (
-                        f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
-                    )
-                else:
-                    prompt += (
-                        f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
-                    )
-            else:
-                prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
-        prompt += f"{AnthropicConstants.AI_PROMPT.value}"
-        if "max_tokens" in optional_params and optional_params["max_tokens"] != float(
-            "inf"
-        ):
-            max_tokens = optional_params["max_tokens"]
-        else:
-            max_tokens = self.default_max_tokens_to_sample
-        data = {
-            "model": model,
-            "prompt": prompt,
-            "max_tokens_to_sample": max_tokens,
-            **optional_params,
-        }
-
-        ## LOGGING
-        self.logging_obj.pre_call(
-            input=prompt,
-            api_key=self.api_key,
-            additional_args={"complete_input_dict": data},
+# makes headers for API call
+def validate_environment(api_key):
+    if api_key is None:
+        raise ValueError(
+            "Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
         )
-        ## COMPLETION CALL
-        if "stream" in optional_params and optional_params["stream"] == True:
-            response = requests.post(
-                self.completion_url,
-                headers=self.headers,
-                data=json.dumps(data),
-                stream=optional_params["stream"],
-            )
-            return response.iter_lines()
-        else:
-            response = requests.post(
-                self.completion_url, headers=self.headers, data=json.dumps(data)
-            )
-            ## LOGGING
-            self.logging_obj.post_call(
-                input=prompt,
-                api_key=self.api_key,
-                original_response=response.text,
-                additional_args={"complete_input_dict": data},
-            )
-            print_verbose(f"raw model_response: {response.text}")
-            ## RESPONSE OBJECT
-            try:
-                completion_response = response.json()
-            except:
-                raise AnthropicError(message=response.text, status_code=response.status_code)
-            if "error" in completion_response:
-                raise AnthropicError(
-                    message=str(completion_response["error"]),
-                    status_code=response.status_code,
+    headers = {
+        "accept": "application/json",
+        "anthropic-version": "2023-06-01",
+        "content-type": "application/json",
+        "x-api-key": api_key,
+    }
+    return headers
+
+def completion(
+    model: str,
+    messages: list,
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    api_key,
+    logging_obj,
+    optional_params=None,
+    litellm_params=None,
+    logger_fn=None,
+):
+    headers = validate_environment(api_key)
+    prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}"
+    for message in messages:
+        if "role" in message:
+            if message["role"] == "user":
+                prompt += (
+                    f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
                 )
             else:
-                model_response["choices"][0]["message"][
-                    "content"
-                ] = completion_response["completion"]
+                prompt += (
+                    f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
+                )
+        else:
+            prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
+    prompt += f"{AnthropicConstants.AI_PROMPT.value}"
+    if "max_tokens" in optional_params and optional_params["max_tokens"] != float("inf"):
+        max_tokens = optional_params["max_tokens"]
+    else:
+        max_tokens = 256 # required anthropic param, default to 256 if user does not provide an input
+    data = {
+        "model": model,
+        "prompt": prompt,
+        "max_tokens_to_sample": max_tokens,
+        **optional_params,
+    }
 
-        ## CALCULATING USAGE
-        prompt_tokens = len(
-            self.encoding.encode(prompt)
-        ) ##[TODO] use the anthropic tokenizer here
-        completion_tokens = len(
-            self.encoding.encode(model_response["choices"][0]["message"]["content"])
-        ) ##[TODO] use the anthropic tokenizer here
+    ## LOGGING
+    logging_obj.pre_call(
+        input=prompt,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
+
+    ## COMPLETION CALL
+    if "stream" in optional_params and optional_params["stream"] == True:
+        response = requests.post(
+            "https://api.anthropic.com/v1/complete",
+            headers=headers,
+            data=json.dumps(data),
+            stream=optional_params["stream"],
+        )
+        return response.iter_lines()
+    else:
+        response = requests.post(
+            "https://api.anthropic.com/v1/complete", headers=headers, data=json.dumps(data)
+        )
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        try:
+            completion_response = response.json()
+        except:
+            raise AnthropicError(
+                message=response.text, status_code=response.status_code
+            )
+        if "error" in completion_response:
+            raise AnthropicError(
+                message=str(completion_response["error"]),
+                status_code=response.status_code,
+            )
+        else:
+            model_response["choices"][0]["message"]["content"] = completion_response[
+                "completion"
+            ]
 
-        model_response["created"] = time.time()
-        model_response["model"] = model
-        model_response["usage"] = {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": completion_tokens,
-            "total_tokens": prompt_tokens + completion_tokens,
-        }
-        return model_response
+        ## CALCULATING USAGE
+        prompt_tokens = len(
+            encoding.encode(prompt)
+        ) ##[TODO] use the anthropic tokenizer here
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"]["content"])
+        ) ##[TODO] use the anthropic tokenizer here
 
-    def embedding(
-        self,
-    ): # logic for parsing in - calling - parsing out model embedding calls
-        pass
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        model_response["usage"] = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+        return model_response
+
+def embedding():
+    # logic for parsing in - calling - parsing out model embedding calls
+    pass
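Note on behavior the refactor preserves: the prompt format built by the new module-level `completion()` is byte-for-byte what the old `AnthropicLLM.completion` method produced. A standalone sketch of the loop's output (the helper name `build_prompt` is illustrative, not part of the patch):

```python
from enum import Enum

class AnthropicConstants(Enum):
    HUMAN_PROMPT = "\n\nHuman:"
    AI_PROMPT = "\n\nAssistant:"

def build_prompt(messages: list) -> str:
    # Mirrors the loop above: "user" turns get the Human prefix, any other
    # role gets the Assistant prefix, and messages with no "role" key fall
    # back to the Human prefix. The prompt always ends with the Assistant
    # prefix so the model knows it is its turn to respond.
    prompt = AnthropicConstants.HUMAN_PROMPT.value
    for message in messages:
        if "role" in message:
            if message["role"] == "user":
                prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
            else:
                prompt += f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
        else:
            prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
    return prompt + AnthropicConstants.AI_PROMPT.value

print(repr(build_prompt([{"role": "user", "content": "Hi"}])))
# '\n\nHuman:\n\nHuman:Hi\n\nAssistant:'
# Note the doubled Human prefix on the first turn: the loop seeds `prompt`
# with HUMAN_PROMPT and then prepends it again per user message. The patch
# carries this quirk over from the class version unchanged.
```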
diff --git a/litellm/main.py b/litellm/main.py
index 09e409293..9d453af46 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -19,7 +19,7 @@ from litellm.utils import (
     read_config_args,
     completion_with_fallbacks,
 )
-from .llms.anthropic import AnthropicLLM
+from .llms import anthropic
 from .llms.huggingface_restapi import HuggingfaceRestAPILLM
 from .llms.baseten import BasetenLLM
 from .llms.ai21 import AI21LLM
@@ -61,7 +61,6 @@ async def acompletion(*args, **kwargs):
 
 
 @client
-# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2), reraise=True, retry_error_callback=lambda retry_state: setattr(retry_state.outcome, 'retry_variable', litellm.retry)) # retry call, turn this off by setting `litellm.retry = False`
 @timeout( # type: ignore
     600
 ) ## set timeouts, in case calls hang (e.g. Azure) - default is 600s, override with `force_timeout`
@@ -79,7 +78,6 @@ def completion(
     max_tokens=float("inf"),
     presence_penalty=0,
     frequency_penalty=0,
-    num_beams=1,
     logit_bias={},
     user="",
     deployment_id=None,
@@ -89,6 +87,7 @@ def completion(
     api_key=None,
     api_version=None,
     force_timeout=600,
+    num_beams=1,
     logger_fn=None,
     verbose=False,
     azure=False,
@@ -407,13 +406,7 @@ def completion(
         anthropic_key = (
             api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY")
         )
-        anthropic_client = AnthropicLLM(
-            encoding=encoding,
-            default_max_tokens_to_sample=litellm.max_tokens,
-            api_key=anthropic_key,
-            logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit anthropic's requirements
-        )
-        model_response = anthropic_client.completion(
+        model_response = anthropic.completion(
             model=model,
             messages=messages,
             model_response=model_response,
@@ -421,6 +414,9 @@ def completion(
             optional_params=optional_params,
             litellm_params=litellm_params,
             logger_fn=logger_fn,
+            encoding=encoding, # for calculating input/output tokens
+            api_key=anthropic_key,
+            logging_obj=logging,
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
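With the class gone, main.py now passes the tokenizer, API key, and logging object straight into the function instead of binding them to an `AnthropicLLM` instance. A hedged sketch of invoking the refactored entry point directly, mirroring the wiring in the hunk above (the stub logger and the `tiktoken` encoding are illustrative stand-ins for what litellm supplies internally, not part of this patch):

```python
import os
import tiktoken  # assumption: any object exposing .encode() works for token counting
from litellm.llms import anthropic
from litellm.utils import ModelResponse

class StubLogger:
    # completion() only ever calls pre_call() and post_call() on logging_obj,
    # so a no-op stand-in is enough for a direct invocation.
    def pre_call(self, *args, **kwargs): pass
    def post_call(self, *args, **kwargs): pass

response = anthropic.completion(
    model="claude-instant-1",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    model_response=ModelResponse(),
    print_verbose=print,
    encoding=tiktoken.get_encoding("cl100k_base"),  # [TODO] in the diff: swap for Anthropic's tokenizer
    api_key=os.environ["ANTHROPIC_API_KEY"],
    logging_obj=StubLogger(),
    optional_params={},  # must be a dict: completion() probes it with `"stream" in ...`
    litellm_params={},
    logger_fn=None,
)
print(response["choices"][0]["message"]["content"])
```

One design note the truncated last hunk hints at: when `optional_params` carries `stream=True`, the refactored `completion()` returns `response.iter_lines()` rather than a `ModelResponse`, which is why the call site guards with the `"stream"` check before touching the response object.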