diff --git a/litellm/__init__.py b/litellm/__init__.py index 0d28d262ee..87be1d002f 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -846,6 +846,7 @@ class LlmProviders(str, Enum): COHERE_CHAT = "cohere_chat" CLARIFAI = "clarifai" ANTHROPIC = "anthropic" + ANTHROPIC_TEXT = "anthropic_text" REPLICATE = "replicate" HUGGINGFACE = "huggingface" TOGETHER_AI = "together_ai" @@ -1060,7 +1061,7 @@ from .llms.anthropic.experimental_pass_through.transformation import ( AnthropicExperimentalPassThroughConfig, ) from .llms.groq.stt.transformation import GroqSTTConfig -from .llms.anthropic.completion import AnthropicTextConfig +from .llms.anthropic.completion.transformation import AnthropicTextConfig from .llms.databricks.chat.transformation import DatabricksConfig from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig from .llms.predibase import PredibaseConfig diff --git a/litellm/constants.py b/litellm/constants.py index c0aa2a3690..1fb97e07fc 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -12,6 +12,7 @@ LITELLM_CHAT_PROVIDERS = [ "cohere_chat", "clarifai", "anthropic", + "anthropic_text", "replicate", "huggingface", "together_ai", diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py index fea931491f..68992361ad 100644 --- a/litellm/litellm_core_utils/get_llm_provider_logic.py +++ b/litellm/litellm_core_utils/get_llm_provider_logic.py @@ -52,6 +52,39 @@ def handle_cohere_chat_model_custom_llm_provider( return model, custom_llm_provider +def handle_anthropic_text_model_custom_llm_provider( + model: str, custom_llm_provider: Optional[str] = None +) -> Tuple[str, Optional[str]]: + """ + if user sets model = "anthropic/claude-2" -> use custom_llm_provider = "anthropic_text" + + Args: + model: + custom_llm_provider: + + Returns: + model, custom_llm_provider + """ + + if custom_llm_provider: + if ( + custom_llm_provider == "anthropic" + and litellm.AnthropicTextConfig._is_anthropic_text_model(model) + ): + return model, "anthropic_text" + + if "/" in model: + _custom_llm_provider, _model = model.split("/", 1) + if ( + _custom_llm_provider + and _custom_llm_provider == "anthropic" + and litellm.AnthropicTextConfig._is_anthropic_text_model(_model) + ): + return _model, "anthropic_text" + + return model, custom_llm_provider + + def get_llm_provider( # noqa: PLR0915 model: str, custom_llm_provider: Optional[str] = None, @@ -92,6 +125,10 @@ def get_llm_provider( # noqa: PLR0915 model, custom_llm_provider ) + model, custom_llm_provider = handle_anthropic_text_model_custom_llm_provider( + model, custom_llm_provider + ) + if custom_llm_provider: if ( model.split("/")[0] == custom_llm_provider @@ -210,7 +247,10 @@ def get_llm_provider( # noqa: PLR0915 custom_llm_provider = "text-completion-openai" ## anthropic elif model in litellm.anthropic_models: - custom_llm_provider = "anthropic" + if litellm.AnthropicTextConfig._is_anthropic_text_model(model): + custom_llm_provider = "anthropic_text" + else: + custom_llm_provider = "anthropic" ## cohere elif model in litellm.cohere_models or model in litellm.cohere_embedding_models: custom_llm_provider = "cohere" @@ -531,7 +571,9 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915 ) elif custom_llm_provider == "galadriel": api_base = ( - api_base or get_secret("GALADRIEL_API_BASE") or "https://api.galadriel.com/v1" + api_base + or get_secret("GALADRIEL_API_BASE") + or "https://api.galadriel.com/v1" ) # type: ignore dynamic_api_key = api_key or 
get_secret_str("GALADRIEL_API_KEY") if api_base is not None and not isinstance(api_base, str): diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 31550ae353..7b1973d9cf 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -223,40 +223,6 @@ class CustomStreamWrapper: self.holding_chunk = "" return hold, curr_chunk - def handle_anthropic_text_chunk(self, chunk): - """ - For old anthropic models - claude-1, claude-2. - - Claude-3 is handled from within Anthropic.py VIA ModelResponseIterator() - """ - str_line = chunk - if isinstance(chunk, bytes): # Handle binary data - str_line = chunk.decode("utf-8") # Convert bytes to string - text = "" - is_finished = False - finish_reason = None - if str_line.startswith("data:"): - data_json = json.loads(str_line[5:]) - type_chunk = data_json.get("type", None) - if type_chunk == "completion": - text = data_json.get("completion") - finish_reason = data_json.get("stop_reason") - if finish_reason is not None: - is_finished = True - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } - elif "error" in str_line: - raise ValueError(f"Unable to parse response. Original response: {str_line}") - else: - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } - def handle_predibase_chunk(self, chunk): try: if not isinstance(chunk, str): @@ -1005,14 +971,6 @@ class CustomStreamWrapper: setattr(model_response, key, value) response_obj = anthropic_response_obj - elif ( - self.custom_llm_provider - and self.custom_llm_provider == "anthropic_text" - ): - response_obj = self.handle_anthropic_text_chunk(chunk) - completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - self.received_finish_reason = response_obj["finish_reason"] elif self.model == "replicate" or self.custom_llm_provider == "replicate": response_obj = self.handle_replicate_chunk(chunk) completion_obj["content"] = response_obj["text"] diff --git a/litellm/llms/anthropic/completion.py b/litellm/llms/anthropic/completion.py deleted file mode 100644 index dc06401d6d..0000000000 --- a/litellm/llms/anthropic/completion.py +++ /dev/null @@ -1,362 +0,0 @@ -""" -Translation logic for anthropic's `/v1/complete` endpoint -""" - -import json -import os -import time -import types -from enum import Enum -from typing import Callable, Optional - -import httpx -import requests - -import litellm -from litellm.llms.custom_httpx.http_handler import ( - AsyncHTTPHandler, - HTTPHandler, - get_async_httpx_client, -) -from litellm.utils import CustomStreamWrapper, ModelResponse, Usage - -from ..base import BaseLLM -from ..prompt_templates.factory import custom_prompt, prompt_factory - - -class AnthropicConstants(Enum): - HUMAN_PROMPT = "\n\nHuman: " - AI_PROMPT = "\n\nAssistant: " - - -class AnthropicError(Exception): - def __init__(self, status_code, message): - self.status_code = status_code - self.message = message - self.request = httpx.Request( - method="POST", url="https://api.anthropic.com/v1/complete" - ) - self.response = httpx.Response(status_code=status_code, request=self.request) - super().__init__( - self.message - ) # Call the base class constructor with the parameters it needs - - -class AnthropicTextConfig: - """ - Reference: https://docs.anthropic.com/claude/reference/complete_post - - to pass metadata to anthropic, it's {"user_id": "any-relevant-information"} - """ - - 
max_tokens_to_sample: Optional[int] = ( - litellm.max_tokens - ) # anthropic requires a default - stop_sequences: Optional[list] = None - temperature: Optional[int] = None - top_p: Optional[int] = None - top_k: Optional[int] = None - metadata: Optional[dict] = None - - def __init__( - self, - max_tokens_to_sample: Optional[int] = 256, # anthropic requires a default - stop_sequences: Optional[list] = None, - temperature: Optional[int] = None, - top_p: Optional[int] = None, - top_k: Optional[int] = None, - metadata: Optional[dict] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - -# makes headers for API call -def validate_environment(api_key, user_headers): - if api_key is None: - raise ValueError( - "Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params" - ) - headers = { - "accept": "application/json", - "anthropic-version": "2023-06-01", - "content-type": "application/json", - "x-api-key": api_key, - } - if user_headers is not None and isinstance(user_headers, dict): - headers = {**headers, **user_headers} - return headers - - -class AnthropicTextCompletion(BaseLLM): - def __init__(self) -> None: - super().__init__() - - def _process_response( - self, model_response: ModelResponse, response, encoding, prompt: str, model: str - ): - ## RESPONSE OBJECT - try: - completion_response = response.json() - except Exception: - raise AnthropicError( - message=response.text, status_code=response.status_code - ) - if "error" in completion_response: - raise AnthropicError( - message=str(completion_response["error"]), - status_code=response.status_code, - ) - else: - if len(completion_response["completion"]) > 0: - model_response.choices[0].message.content = completion_response[ # type: ignore - "completion" - ] - model_response.choices[0].finish_reason = completion_response["stop_reason"] - - ## CALCULATING USAGE - prompt_tokens = len( - encoding.encode(prompt) - ) ##[TODO] use the anthropic tokenizer here - completion_tokens = len( - encoding.encode(model_response["choices"][0]["message"].get("content", "")) - ) ##[TODO] use the anthropic tokenizer here - - model_response.created = int(time.time()) - model_response.model = model - usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - - setattr(model_response, "usage", usage) - - return model_response - - async def async_completion( - self, - model: str, - model_response: ModelResponse, - api_base: str, - logging_obj, - encoding, - headers: dict, - data: dict, - client=None, - ): - if client is None: - client = get_async_httpx_client( - llm_provider=litellm.LlmProviders.ANTHROPIC, - params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)}, - ) - - response = await client.post(api_base, headers=headers, data=json.dumps(data)) - - if response.status_code != 200: - raise AnthropicError( - status_code=response.status_code, message=response.text - ) - - ## LOGGING - logging_obj.post_call( - input=data["prompt"], - api_key=headers.get("x-api-key"), - original_response=response.text, - 
additional_args={"complete_input_dict": data}, - ) - - response = self._process_response( - model_response=model_response, - response=response, - encoding=encoding, - prompt=data["prompt"], - model=model, - ) - return response - - async def async_streaming( - self, - model: str, - api_base: str, - logging_obj, - headers: dict, - data: Optional[dict], - client=None, - ): - if client is None: - client = get_async_httpx_client( - llm_provider=litellm.LlmProviders.ANTHROPIC, - params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)}, - ) - - response = await client.post(api_base, headers=headers, data=json.dumps(data)) - - if response.status_code != 200: - raise AnthropicError( - status_code=response.status_code, message=response.text - ) - - completion_stream = response.aiter_lines() - - streamwrapper = CustomStreamWrapper( - completion_stream=completion_stream, - model=model, - custom_llm_provider="anthropic_text", - logging_obj=logging_obj, - ) - return streamwrapper - - def completion( - self, - model: str, - messages: list, - api_base: str, - acompletion: str, - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - api_key, - logging_obj, - optional_params: dict, - litellm_params=None, - logger_fn=None, - headers={}, - client=None, - ): - headers = validate_environment(api_key, headers) - if model in custom_prompt_dict: - # check if the model has a registered custom prompt - model_prompt_details = custom_prompt_dict[model] - prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages, - ) - else: - prompt = prompt_factory( - model=model, messages=messages, custom_llm_provider="anthropic" - ) - - ## Load Config - config = litellm.AnthropicTextConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - data = { - "model": model, - "prompt": prompt, - **optional_params, - } - - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=api_key, - additional_args={ - "complete_input_dict": data, - "api_base": api_base, - "headers": headers, - }, - ) - - ## COMPLETION CALL - if "stream" in optional_params and optional_params["stream"] is True: - if acompletion is True: - return self.async_streaming( - model=model, - api_base=api_base, - logging_obj=logging_obj, - headers=headers, - data=data, - client=None, - ) - - if client is None: - client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) - - response = client.post( - api_base, - headers=headers, - data=json.dumps(data), - # stream=optional_params["stream"], - ) - - if response.status_code != 200: - raise AnthropicError( - status_code=response.status_code, message=response.text - ) - completion_stream = response.iter_lines() - stream_response = CustomStreamWrapper( - completion_stream=completion_stream, - model=model, - custom_llm_provider="anthropic_text", - logging_obj=logging_obj, - ) - return stream_response - elif acompletion is True: - return self.async_completion( - model=model, - model_response=model_response, - api_base=api_base, - logging_obj=logging_obj, - encoding=encoding, - headers=headers, - data=data, - client=client, - ) - else: - if client is None: - client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) - response = 
client.post(api_base, headers=headers, data=json.dumps(data)) - if response.status_code != 200: - raise AnthropicError( - status_code=response.status_code, message=response.text - ) - - ## LOGGING - logging_obj.post_call( - input=prompt, - api_key=api_key, - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) - print_verbose(f"raw model_response: {response.text}") - - response = self._process_response( - model_response=model_response, - response=response, - encoding=encoding, - prompt=data["prompt"], - model=model, - ) - return response - - def embedding(self): - # logic for parsing in - calling - parsing out model embedding calls - pass diff --git a/litellm/llms/anthropic/completion/handler.py b/litellm/llms/anthropic/completion/handler.py new file mode 100644 index 0000000000..f1c8be7bbc --- /dev/null +++ b/litellm/llms/anthropic/completion/handler.py @@ -0,0 +1,5 @@ +""" +Anthropic /complete API - uses `llm_http_handler.py` to make httpx requests + +Request/Response transformation is handled in `transformation.py` +""" diff --git a/litellm/llms/anthropic/completion/transformation.py b/litellm/llms/anthropic/completion/transformation.py new file mode 100644 index 0000000000..7436327324 --- /dev/null +++ b/litellm/llms/anthropic/completion/transformation.py @@ -0,0 +1,307 @@ +""" +Translation logic for anthropic's `/v1/complete` endpoint + +Litellm provider slug: `anthropic_text/` +""" + +import json +import time +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union + +import httpx + +import litellm +from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator +from litellm.llms.base_llm.transformation import ( + BaseConfig, + BaseLLMException, + LiteLLMLoggingObj, +) +from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import ( + ChatCompletionToolCallChunk, + ChatCompletionUsageBlock, + GenericStreamingChunk, + ModelResponse, + Usage, +) + + +class AnthropicTextError(BaseLLMException): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + self.request = httpx.Request( + method="POST", url="https://api.anthropic.com/v1/complete" + ) + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + message=self.message, + status_code=self.status_code, + request=self.request, + response=self.response, + ) # Call the base class constructor with the parameters it needs + + +class AnthropicTextConfig(BaseConfig): + """ + Reference: https://docs.anthropic.com/claude/reference/complete_post + + to pass metadata to anthropic, it's {"user_id": "any-relevant-information"} + """ + + max_tokens_to_sample: Optional[int] = ( + litellm.max_tokens + ) # anthropic requires a default + stop_sequences: Optional[list] = None + temperature: Optional[int] = None + top_p: Optional[int] = None + top_k: Optional[int] = None + metadata: Optional[dict] = None + + def __init__( + self, + max_tokens_to_sample: Optional[int] = 256, # anthropic requires a default + stop_sequences: Optional[list] = None, + temperature: Optional[int] = None, + top_p: Optional[int] = None, + top_k: Optional[int] = None, + metadata: Optional[dict] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + # makes headers for API call + def validate_environment( + self, 
+ headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + api_key: Optional[str] = None, + ) -> dict: + if api_key is None: + raise ValueError( + "Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params" + ) + _headers = { + "accept": "application/json", + "anthropic-version": "2023-06-01", + "content-type": "application/json", + "x-api-key": api_key, + } + headers.update(_headers) + return headers + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + prompt = self._get_anthropic_text_prompt_from_messages( + messages=messages, model=model + ) + ## Load Config + config = litellm.AnthropicTextConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + data = { + "model": model, + "prompt": prompt, + **optional_params, + } + + return data + + def get_supported_openai_params(self, model: str): + """ + Anthropic /complete API Ref: https://docs.anthropic.com/en/api/complete + """ + return [ + "stream", + "max_tokens", + "max_completion_tokens", + "stop", + "temperature", + "top_p", + "extra_headers", + "user", + ] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + """ + Follows the same logic as the AnthropicConfig.map_openai_params method (which is the Anthropic /messages API) + + Note: the only difference is in the get supported openai params method between the AnthropicConfig and AnthropicTextConfig + API Ref: https://docs.anthropic.com/en/api/complete + """ + for param, value in non_default_params.items(): + if param == "max_tokens": + optional_params["max_tokens_to_sample"] = value + if param == "max_completion_tokens": + optional_params["max_tokens_to_sample"] = value + if param == "stream" and value is True: + optional_params["stream"] = value + if param == "stop" and (isinstance(value, str) or isinstance(value, list)): + _value = litellm.AnthropicConfig()._map_stop_sequences(value) + if _value is not None: + optional_params["stop_sequences"] = _value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["top_p"] = value + if param == "user": + optional_params["metadata"] = {"user_id": value} + + return optional_params + + def transform_response( + self, + model: str, + raw_response: httpx.Response, + model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + request_data: dict, + messages: List[AllMessageValues], + optional_params: dict, + encoding: str, + api_key: Optional[str] = None, + json_mode: Optional[bool] = None, + ) -> ModelResponse: + try: + completion_response = raw_response.json() + except Exception: + raise AnthropicTextError( + message=raw_response.text, status_code=raw_response.status_code + ) + prompt = self._get_anthropic_text_prompt_from_messages( + messages=messages, model=model + ) + if "error" in completion_response: + raise AnthropicTextError( + message=str(completion_response["error"]), + status_code=raw_response.status_code, + ) + else: + if len(completion_response["completion"]) > 0: + model_response.choices[0].message.content = completion_response[ # type: ignore + "completion" + ] + model_response.choices[0].finish_reason = 
completion_response["stop_reason"] + + ## CALCULATING USAGE + prompt_tokens = len( + encoding.encode(prompt) + ) ##[TODO] use the anthropic tokenizer here + completion_tokens = len( + encoding.encode(model_response["choices"][0]["message"].get("content", "")) + ) ##[TODO] use the anthropic tokenizer here + + model_response.created = int(time.time()) + model_response.model = model + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + + setattr(model_response, "usage", usage) + return model_response + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers] + ) -> BaseLLMException: + return AnthropicTextError( + status_code=status_code, + message=error_message, + ) + + @staticmethod + def _is_anthropic_text_model(model: str) -> bool: + return model == "claude-2" or model == "claude-instant-1" + + def _get_anthropic_text_prompt_from_messages( + self, messages: List[AllMessageValues], model: str + ) -> str: + custom_prompt_dict = litellm.custom_prompt_dict + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, + ) + else: + prompt = prompt_factory( + model=model, messages=messages, custom_llm_provider="anthropic" + ) + + return str(prompt) + + def _transform_messages( + self, messages: List[AllMessageValues] + ) -> List[AllMessageValues]: + "Not required" + raise NotImplementedError + + def get_model_response_iterator( + self, + streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse], + sync_stream: bool, + json_mode: Optional[bool] = False, + ): + return AnthropicTextCompletionResponseIterator( + streaming_response=streaming_response, + sync_stream=sync_stream, + json_mode=json_mode, + ) + + +class AnthropicTextCompletionResponseIterator(BaseModelResponseIterator): + def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: + try: + text = "" + tool_use: Optional[ChatCompletionToolCallChunk] = None + is_finished = False + finish_reason = "" + usage: Optional[ChatCompletionUsageBlock] = None + provider_specific_fields = None + index = int(chunk.get("index", 0)) + _chunk_text = chunk.get("completion", None) + if _chunk_text is not None and isinstance(_chunk_text, str): + text = _chunk_text + finish_reason = chunk.get("stop_reason", None) + if finish_reason is not None: + is_finished = True + returned_chunk = GenericStreamingChunk( + text=text, + tool_use=tool_use, + is_finished=is_finished, + finish_reason=finish_reason, + usage=usage, + index=index, + provider_specific_fields=provider_specific_fields, + ) + + return returned_chunk + + except json.JSONDecodeError: + raise ValueError(f"Failed to decode JSON from chunk: {chunk}") diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index c218377e51..490a39c29f 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -2839,7 +2839,7 @@ def prompt_factory( if custom_llm_provider == "ollama": return ollama_pt(model=model, messages=messages) elif custom_llm_provider == "anthropic": - if model == "claude-instant-1" or model == "claude-2": + if 
litellm.AnthropicTextConfig._is_anthropic_text_model(model):
             return anthropic_pt(messages=messages)
         return anthropic_messages_pt(
             messages=messages, model=model, llm_provider=custom_llm_provider
diff --git a/litellm/main.py b/litellm/main.py
index b7fd631ec6..10142bdb25 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -99,7 +99,6 @@ from .llms import (
 )
 from .llms.ai21 import completion as ai21
 from .llms.anthropic.chat import AnthropicChatCompletion
-from .llms.anthropic.completion import AnthropicTextCompletion
 from .llms.azure.audio_transcriptions import AzureAudioTranscription
 from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
 from .llms.azure.chat.o1_handler import AzureOpenAIO1ChatCompletion
@@ -204,7 +203,6 @@ together_ai_text_completions = TogetherAITextCompletion()
 azure_ai_chat_completions = AzureAIChatCompletion()
 azure_ai_embedding = AzureAIEmbedding()
 anthropic_chat_completions = AnthropicChatCompletion()
-anthropic_text_completions = AnthropicTextCompletion()
 azure_chat_completions = AzureChatCompletion()
 azure_o1_chat_completions = AzureOpenAIO1ChatCompletion()
 azure_text_completions = AzureTextCompletion()
@@ -464,6 +462,7 @@ async def acompletion(
         or custom_llm_provider == "sagemaker"
         or custom_llm_provider == "sagemaker_chat"
         or custom_llm_provider == "anthropic"
+        or custom_llm_provider == "anthropic_text"
         or custom_llm_provider == "predibase"
         or custom_llm_provider == "bedrock"
         or custom_llm_provider == "databricks"
@@ -1705,6 +1704,41 @@ def completion(  # type: ignore # noqa: PLR0915
             api_key=clarifai_key,
             logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
         )
+    elif custom_llm_provider == "anthropic_text":
+        api_key = (
+            api_key
+            or litellm.anthropic_key
+            or litellm.api_key
+            or os.environ.get("ANTHROPIC_API_KEY")
+        )
+        custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret("ANTHROPIC_API_BASE")
+            or get_secret("ANTHROPIC_BASE_URL")
+            or "https://api.anthropic.com/v1/complete"
+        )
+
+        if api_base is not None and not api_base.endswith("/v1/complete"):
+            api_base += "/v1/complete"
+
+        response = base_llm_http_handler.completion(
+            model=model,
+            stream=stream,
+            messages=messages,
+            acompletion=acompletion,
+            api_base=api_base,
+            model_response=model_response,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            custom_llm_provider="anthropic_text",
+            timeout=timeout,
+            headers=headers,
+            encoding=encoding,
+            api_key=api_key,
+            logging_obj=logging,  # model call logging done inside the class as we may need to modify I/O to fit anthropic's requirements
+        )
     elif custom_llm_provider == "anthropic":
         api_key = (
             api_key
@@ -1713,69 +1747,38 @@
             or os.environ.get("ANTHROPIC_API_KEY")
         )
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
+        # call /messages
+        # default route for all anthropic models
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret("ANTHROPIC_API_BASE")
+            or get_secret("ANTHROPIC_BASE_URL")
+            or "https://api.anthropic.com/v1/messages"
+        )
 
-        if (model == "claude-2") or (model == "claude-instant-1"):
-            # call anthropic /completion, only use this route for claude-2, claude-instant-1
-            api_base = (
-                api_base
-                or litellm.api_base
-                or get_secret("ANTHROPIC_API_BASE")
-                or get_secret("ANTHROPIC_BASE_URL")
-                or "https://api.anthropic.com/v1/complete"
-            )
+        if api_base is not None and not api_base.endswith("/v1/messages"):
+            api_base += "/v1/messages"
 
-            if api_base is not None and not api_base.endswith("/v1/complete"):
-                api_base += "/v1/complete"
-
-            response = anthropic_text_completions.completion(
-                model=model,
-                messages=messages,
-                api_base=api_base,
-                acompletion=acompletion,
-                custom_prompt_dict=litellm.custom_prompt_dict,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                encoding=encoding,  # for calculating input/output tokens
-                api_key=api_key,
-                logging_obj=logging,
-                headers=headers,
-            )
-        else:
-            # call /messages
-            # default route for all anthropic models
-            api_base = (
-                api_base
-                or litellm.api_base
-                or get_secret("ANTHROPIC_API_BASE")
-                or get_secret("ANTHROPIC_BASE_URL")
-                or "https://api.anthropic.com/v1/messages"
-            )
-
-            if api_base is not None and not api_base.endswith("/v1/messages"):
-                api_base += "/v1/messages"
-
-            response = anthropic_chat_completions.completion(
-                model=model,
-                messages=messages,
-                api_base=api_base,
-                acompletion=acompletion,
-                custom_prompt_dict=litellm.custom_prompt_dict,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                encoding=encoding,  # for calculating input/output tokens
-                api_key=api_key,
-                logging_obj=logging,
-                headers=headers,
-                timeout=timeout,
-                client=client,
-                custom_llm_provider=custom_llm_provider,
-            )
+        response = anthropic_chat_completions.completion(
+            model=model,
+            messages=messages,
+            api_base=api_base,
+            acompletion=acompletion,
+            custom_prompt_dict=litellm.custom_prompt_dict,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            encoding=encoding,  # for calculating input/output tokens
+            api_key=api_key,
+            logging_obj=logging,
+            headers=headers,
+            timeout=timeout,
+            client=client,
+            custom_llm_provider=custom_llm_provider,
+        )
         if optional_params.get("stream", False) or acompletion is True:
             ## LOGGING
             logging.post_call(
diff --git a/litellm/utils.py b/litellm/utils.py
index ebf8115f86..5321357a87 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2830,6 +2830,22 @@ def get_optional_params(  # noqa: PLR0915
                 else False
             ),
         )
+    elif custom_llm_provider == "anthropic_text":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.AnthropicTextConfig().map_openai_params(
+            model=model,
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
+
     elif custom_llm_provider == "cohere":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
@@ -4208,7 +4234,7 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
     if llm_provider == "openai" or llm_provider == "text-completion-openai":
         api_key = api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
     # anthropic
-    elif llm_provider == "anthropic":
+    elif 
llm_provider == "anthropic" or llm_provider == "anthropic_text": api_key = api_key or litellm.anthropic_key or get_secret("ANTHROPIC_API_KEY") # ai21 elif llm_provider == "ai21": @@ -6251,6 +6277,8 @@ class ProviderConfigManager: return litellm.ClarifaiConfig() elif litellm.LlmProviders.ANTHROPIC == provider: return litellm.AnthropicConfig() + elif litellm.LlmProviders.ANTHROPIC_TEXT == provider: + return litellm.AnthropicTextConfig() elif litellm.LlmProviders.VERTEX_AI == provider: if "claude" in model: return litellm.VertexAIAnthropicConfig() diff --git a/tests/llm_translation/test_anthropic_text_completion.py b/tests/llm_translation/test_anthropic_text_completion.py new file mode 100644 index 0000000000..bd0aefda18 --- /dev/null +++ b/tests/llm_translation/test_anthropic_text_completion.py @@ -0,0 +1,73 @@ +import asyncio +import os +from re import T +import sys +import traceback + +from dotenv import load_dotenv + +import litellm.types +import litellm.types.utils +from litellm.llms.anthropic.chat import ModelResponseIterator + +load_dotenv() +import io +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +from typing import Optional +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model", ["claude-2", "anthropic/claude-2"]) +async def test_acompletion_claude2(model): + try: + litellm.set_verbose = True + messages = [ + { + "role": "system", + "content": "Your goal is generate a joke on the topic user gives.", + }, + {"role": "user", "content": "Generate a 3 liner joke for me"}, + ] + # test without max-tokens + response = await litellm.acompletion(model=model, messages=messages) + # Add any assertions here to check the response + print(response) + print(response.usage) + print(response.usage.completion_tokens) + print(response["usage"]["completion_tokens"]) + # print("new cost tracking") + except litellm.InternalServerError: + pytest.skip("model is overloaded.") + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + +@pytest.mark.asyncio +async def test_acompletion_claude2_stream(): + try: + litellm.set_verbose = False + messages = [ + { + "role": "system", + "content": "Your goal is generate a joke on the topic user gives.", + }, + {"role": "user", "content": "Generate a 3 liner joke for me"}, + ] + # test without max-tokens + response = await litellm.acompletion( + model="anthropic_text/claude-2", + messages=messages, + stream=True, + max_tokens=10, + ) + async for chunk in response: + print(chunk) + except Exception as e: + pytest.fail(f"Error occurred: {e}") diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index 9481337fbf..0f8addf775 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -1126,32 +1126,6 @@ def test_completion_mistral_api_modified_input(): pytest.fail(f"Error occurred: {e}") -@pytest.mark.asyncio -async def test_acompletion_claude2_1(): - try: - litellm.set_verbose = True - print("claude2.1 test request") - messages = [ - { - "role": "system", - "content": "Your goal is generate a joke on the topic user gives.", - }, - {"role": "user", "content": "Generate a 3 liner joke for me"}, - ] - # test without max-tokens - response = await litellm.acompletion(model="claude-2.1", messages=messages) - # Add any assertions here to check the response - print(response) - print(response.usage) - print(response.usage.completion_tokens) - 
print(response["usage"]["completion_tokens"]) - # print("new cost tracking") - except litellm.InternalServerError: - pytest.skip("model is overloaded.") - except Exception as e: - pytest.fail(f"Error occurred: {e}") - - # def test_completion_oobabooga(): # try: # response = completion(