diff --git a/litellm/llms/anthropic_text.py b/litellm/llms/anthropic_text.py index c9a9adfc26..77c8e04d9b 100644 --- a/litellm/llms/anthropic_text.py +++ b/litellm/llms/anthropic_text.py @@ -8,6 +8,8 @@ from litellm.utils import ModelResponse, Usage, CustomStreamWrapper import litellm from .prompt_templates.factory import prompt_factory, custom_prompt import httpx +from .base import BaseLLM +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler class AnthropicConstants(Enum): @@ -94,98 +96,13 @@ def validate_environment(api_key, user_headers): return headers -def completion( - model: str, - messages: list, - api_base: str, - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - api_key, - logging_obj, - optional_params=None, - litellm_params=None, - logger_fn=None, - headers={}, -): - headers = validate_environment(api_key, headers) - if model in custom_prompt_dict: - # check if the model has a registered custom prompt - model_prompt_details = custom_prompt_dict[model] - prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages, - ) - else: - prompt = prompt_factory( - model=model, messages=messages, custom_llm_provider="anthropic" - ) +class AnthropicTextCompletion(BaseLLM): + def __init__(self) -> None: + super().__init__() - ## Load Config - config = litellm.AnthropicTextConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - data = { - "model": model, - "prompt": prompt, - **optional_params, - } - - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=api_key, - additional_args={ - "complete_input_dict": data, - "api_base": api_base, - "headers": headers, - }, - ) - - ## COMPLETION CALL - if "stream" in optional_params and optional_params["stream"] == True: - response = requests.post( - api_base, - headers=headers, - data=json.dumps(data), - stream=optional_params["stream"], - ) - - if response.status_code != 200: - raise AnthropicError( - status_code=response.status_code, message=response.text - ) - completion_stream = response.iter_lines() - stream_response = CustomStreamWrapper( - completion_stream=completion_stream, - model=model, - custom_llm_provider="anthropic", - logging_obj=logging_obj, - ) - return stream_response - - else: - response = requests.post(api_base, headers=headers, data=json.dumps(data)) - if response.status_code != 200: - raise AnthropicError( - status_code=response.status_code, message=response.text - ) - - ## LOGGING - logging_obj.post_call( - input=prompt, - api_key=api_key, - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) - print_verbose(f"raw model_response: {response.text}") + def process_response( + self, model_response: ModelResponse, response, encoding, prompt: str, model: str + ): ## RESPONSE OBJECT try: completion_response = response.json() @@ -221,9 +138,206 @@ def completion( total_tokens=prompt_tokens + completion_tokens, ) model_response.usage = usage + return model_response + async def async_completion( + self, + model: str, + model_response: ModelResponse, + api_base: str, + logging_obj, + encoding, + headers: dict, + data: dict, + client=None, + ): + if client is None: + client = 
AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) -def embedding(): - # logic for parsing in - calling - parsing out model embedding calls - pass + response = await client.post(api_base, headers=headers, data=json.dumps(data)) + + if response.status_code != 200: + raise AnthropicError( + status_code=response.status_code, message=response.text + ) + + ## LOGGING + logging_obj.post_call( + input=data["prompt"], + api_key=headers.get("x-api-key"), + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + + response = self.process_response( + model_response=model_response, + response=response, + encoding=encoding, + prompt=data["prompt"], + model=model, + ) + return response + + async def async_streaming( + self, + model: str, + api_base: str, + logging_obj, + headers: dict, + data: Optional[dict], + client=None, + ): + if client is None: + client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) + + response = await client.post(api_base, headers=headers, data=json.dumps(data)) + + if response.status_code != 200: + raise AnthropicError( + status_code=response.status_code, message=response.text + ) + + completion_stream = response.aiter_lines() + + streamwrapper = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="anthropic_text", + logging_obj=logging_obj, + ) + return streamwrapper + + def completion( + self, + model: str, + messages: list, + api_base: str, + acompletion: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + optional_params=None, + litellm_params=None, + logger_fn=None, + headers={}, + client=None, + ): + headers = validate_environment(api_key, headers) + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, + ) + else: + prompt = prompt_factory( + model=model, messages=messages, custom_llm_provider="anthropic" + ) + + ## Load Config + config = litellm.AnthropicTextConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + data = { + "model": model, + "prompt": prompt, + **optional_params, + } + + ## LOGGING + logging_obj.pre_call( + input=prompt, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + + ## COMPLETION CALL + if "stream" in optional_params and optional_params["stream"] == True: + if acompletion == True: + return self.async_streaming( + model=model, + api_base=api_base, + logging_obj=logging_obj, + headers=headers, + data=data, + client=None, + ) + + if client is None: + client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) + + response = client.post( + api_base, + headers=headers, + data=json.dumps(data), + # stream=optional_params["stream"], + ) + + if response.status_code != 200: + raise AnthropicError( + status_code=response.status_code, message=response.text + ) + completion_stream = response.iter_lines() + stream_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + 
custom_llm_provider="anthropic_text", + logging_obj=logging_obj, + ) + return stream_response + elif acompletion == True: + return self.async_completion( + model=model, + model_response=model_response, + api_base=api_base, + logging_obj=logging_obj, + encoding=encoding, + headers=headers, + data=data, + client=client, + ) + else: + if client is None: + client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) + response = client.post(api_base, headers=headers, data=json.dumps(data)) + if response.status_code != 200: + raise AnthropicError( + status_code=response.status_code, message=response.text + ) + + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + + response = self.process_response( + model_response=model_response, + response=response, + encoding=encoding, + prompt=data["prompt"], + model=model, + ) + return response + + def embedding(self): + # logic for parsing in - calling - parsing out model embedding calls + pass diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 67e6c80da6..dd03e7dbec 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -58,13 +58,16 @@ class AsyncHTTPHandler: class HTTPHandler: - def __init__(self, concurrent_limit=1000): + def __init__( + self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000 + ): # Create a client with a connection pool self.client = httpx.Client( + timeout=timeout, limits=httpx.Limits( max_connections=concurrent_limit, max_keepalive_connections=concurrent_limit, - ) + ), ) def close(self): diff --git a/litellm/main.py b/litellm/main.py index 11fef23fd8..f23347942d 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -67,6 +67,7 @@ from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion from .llms.azure import AzureChatCompletion from .llms.azure_text import AzureTextCompletion from .llms.anthropic import AnthropicChatCompletion +from .llms.anthropic_text import AnthropicTextCompletion from .llms.huggingface_restapi import Huggingface from .llms.prompt_templates.factory import ( prompt_factory, @@ -99,6 +100,7 @@ dotenv.load_dotenv() # Loading env variables using dotenv openai_chat_completions = OpenAIChatCompletion() openai_text_completions = OpenAITextCompletion() anthropic_chat_completions = AnthropicChatCompletion() +anthropic_text_completions = AnthropicTextCompletion() azure_chat_completions = AzureChatCompletion() azure_text_completions = AzureTextCompletion() huggingface = Huggingface() @@ -1165,10 +1167,11 @@ def completion( or get_secret("ANTHROPIC_API_BASE") or "https://api.anthropic.com/v1/complete" ) - response = anthropic_text.completion( + response = anthropic_text_completions.completion( model=model, messages=messages, api_base=api_base, + acompletion=acompletion, custom_prompt_dict=litellm.custom_prompt_dict, model_response=model_response, print_verbose=print_verbose, diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 89b1811494..cbadb23c5d 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -492,6 +492,31 @@ def test_completion_claude2_1(): # test_completion_claude2_1() + +@pytest.mark.asyncio +async def test_acompletion_claude2_1(): + try: + litellm.set_verbose = True + print("claude2.1 test request") + messages = [ + { + "role": "system", + 
"content": "Your goal is generate a joke on the topic user gives.", + }, + {"role": "user", "content": "Generate a 3 liner joke for me"}, + ] + # test without max tokens + response = await litellm.acompletion(model="claude-2.1", messages=messages) + # Add any assertions here to check the response + print(response) + print(response.usage) + print(response.usage.completion_tokens) + print(response["usage"]["completion_tokens"]) + # print("new cost tracking") + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + # def test_completion_oobabooga(): # try: # response = completion( diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 988a1a8e94..098869a285 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -380,6 +380,51 @@ def test_completion_claude_stream(): # test_completion_claude_stream() +def test_completion_claude_2_stream(): + litellm.set_verbose = True + response = completion( + model="claude-2", + messages=[{"role": "user", "content": "hello from litellm"}], + stream=True, + ) + complete_response = "" + # Add any assertions here to check the response + idx = 0 + for chunk in response: + print(chunk) + # print(chunk.choices[0].delta) + chunk, finished = streaming_format_tests(idx, chunk) + if finished: + break + complete_response += chunk + idx += 1 + if complete_response.strip() == "": + raise Exception("Empty response received") + print(f"completion_response: {complete_response}") + + +@pytest.mark.asyncio +async def test_acompletion_claude_2_stream(): + litellm.set_verbose = True + response = await litellm.acompletion( + model="claude-2", + messages=[{"role": "user", "content": "hello from litellm"}], + stream=True, + ) + complete_response = "" + # Add any assertions here to check the response + idx = 0 + async for chunk in response: + print(chunk) + # print(chunk.choices[0].delta) + chunk, finished = streaming_format_tests(idx, chunk) + if finished: + break + complete_response += chunk + idx += 1 + if complete_response.strip() == "": + raise Exception("Empty response received") + print(f"completion_response: {complete_response}") def test_completion_palm_stream(): diff --git a/litellm/utils.py b/litellm/utils.py index 8440e361a6..2d110df533 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8845,6 +8845,35 @@ class CustomStreamWrapper: self.holding_chunk = "" return hold, curr_chunk + def handle_anthropic_text_chunk(self, chunk): + str_line = chunk + if isinstance(chunk, bytes): # Handle binary data + str_line = chunk.decode("utf-8") # Convert bytes to string + text = "" + is_finished = False + finish_reason = None + if str_line.startswith("data:"): + data_json = json.loads(str_line[5:]) + type_chunk = data_json.get("type", None) + if type_chunk == "completion": + text = data_json.get("completion") + finish_reason = data_json.get("stop_reason") + if finish_reason is not None: + is_finished = True + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } + elif "error" in str_line: + raise ValueError(f"Unable to parse response. 
Original response: {str_line}") + else: + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } + def handle_anthropic_chunk(self, chunk): str_line = chunk if isinstance(chunk, bytes): # Handle binary data @@ -9532,6 +9561,14 @@ class CustomStreamWrapper: completion_obj["content"] = response_obj["text"] if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] + elif ( + self.custom_llm_provider + and self.custom_llm_provider == "anthropic_text" + ): + response_obj = self.handle_anthropic_text_chunk(chunk) + completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: + self.received_finish_reason = response_obj["finish_reason"] elif self.model == "replicate" or self.custom_llm_provider == "replicate": response_obj = self.handle_replicate_chunk(chunk) completion_obj["content"] = response_obj["text"] @@ -10109,6 +10146,7 @@ class CustomStreamWrapper: or self.custom_llm_provider == "text-completion-openai" or self.custom_llm_provider == "azure_text" or self.custom_llm_provider == "anthropic" + or self.custom_llm_provider == "anthropic_text" or self.custom_llm_provider == "huggingface" or self.custom_llm_provider == "ollama" or self.custom_llm_provider == "ollama_chat"
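
The new handle_anthropic_text_chunk() above consumes Anthropic's legacy text-completion SSE lines. A minimal, illustrative sketch of that parsing follows; the field names ("type", "completion", "stop_reason") are taken from the handler itself, while the concrete payloads here are invented for demonstration.

import json

# Hypothetical SSE lines shaped like the ones handle_anthropic_text_chunk() reads.
# Field names come from the handler above; the values are made up.
sample_lines = [
    'data: {"type": "completion", "completion": "Hello", "stop_reason": null}',
    'data: {"type": "completion", "completion": "!", "stop_reason": "stop_sequence"}',
]

for line in sample_lines:
    payload = json.loads(line[len("data:"):])  # strip the "data:" prefix, decode JSON
    finish_reason = payload.get("stop_reason")
    print({
        "text": payload.get("completion", ""),
        "is_finished": finish_reason is not None,
        "finish_reason": finish_reason,
    })
# {'text': 'Hello', 'is_finished': False, 'finish_reason': None}
# {'text': '!', 'is_finished': True, 'finish_reason': 'stop_sequence'}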
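
Taken together, the refactor exposes four call paths for the Claude text-completion models: sync, sync streaming, async, and async streaming. A minimal usage sketch through the public litellm API is shown below, assuming ANTHROPIC_API_KEY is set in the environment; the chunk/delta access mirrors the OpenAI-style objects litellm returns, as exercised by the tests above.

import asyncio
import litellm

messages = [{"role": "user", "content": "hello from litellm"}]

# Sync, non-streaming: routed through HTTPHandler.post + process_response
resp = litellm.completion(model="claude-2", messages=messages)
print(resp.choices[0].message.content)

# Sync, streaming: CustomStreamWrapper over response.iter_lines()
for chunk in litellm.completion(model="claude-2", messages=messages, stream=True):
    print(chunk.choices[0].delta.content or "", end="")

async def main():
    # Async, non-streaming: AsyncHTTPHandler.post + process_response
    resp = await litellm.acompletion(model="claude-2", messages=messages)
    print(resp.usage.completion_tokens)

    # Async, streaming: CustomStreamWrapper over response.aiter_lines()
    stream = await litellm.acompletion(model="claude-2", messages=messages, stream=True)
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())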