diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 0c7d22aa1b..a153c49304
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index ddeb7a4400..6fd96bed60
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/proxy/llm.py b/litellm/proxy/llm.py
new file mode 100644
index 0000000000..878131697e
--- /dev/null
+++ b/litellm/proxy/llm.py
@@ -0,0 +1,141 @@
+from typing import Dict, Optional
+from collections import defaultdict
+import threading
+import os, subprocess, traceback, json
+from fastapi import HTTPException
+from fastapi.responses import StreamingResponse
+
+import backoff
+import openai.error
+
+import litellm
+import litellm.exceptions
+
+cost_dict: Dict[str, Dict[str, float]] = defaultdict(dict)
+cost_dict_lock = threading.Lock()
+
+debug = False
+##### HELPER FUNCTIONS #####
+def print_verbose(print_statement):
+    global debug
+    if debug:
+        print(print_statement)
+
+# for streaming
+def data_generator(response):
+    print_verbose("inside generator")
+    for chunk in response:
+        print_verbose(f"returned chunk: {chunk}")
+        yield f"data: {json.dumps(chunk)}\n\n"
+
+def run_ollama_serve():
+    command = ['ollama', 'serve']
+
+    with open(os.devnull, 'w') as devnull:
+        process = subprocess.Popen(command, stdout=devnull, stderr=devnull)
+
+##### ERROR HANDLING #####
+class RetryConstantError(Exception):
+    pass
+
+
+class RetryExpoError(Exception):
+    pass
+
+
+class UnknownLLMError(Exception):
+    pass
+
+
+def handle_llm_exception(e: Exception, user_api_base: Optional[str]=None):
+    print(f"\033[1;31mLiteLLM.Exception: {str(e)}\033[0m")
+    if isinstance(e, openai.error.ServiceUnavailableError) and e.llm_provider == "ollama":
+        run_ollama_serve()
+    if isinstance(e, openai.error.InvalidRequestError) and e.llm_provider == "ollama":
+        completion_call_details = {}
+        completion_call_details["model"] = e.model
+        if user_api_base:
+            completion_call_details["api_base"] = user_api_base
+        else:
+            completion_call_details["api_base"] = None
+        print(f"\033[1;31mLiteLLM.Exception: Invalid API Call. Call details: Model: \033[1;37m{e.model}\033[1;31m; LLM Provider: \033[1;37m{e.llm_provider}\033[1;31m; Custom API Base - \033[1;37m{completion_call_details['api_base']}\033[1;31m\033[0m")
+        if completion_call_details["api_base"] == "http://localhost:11434":
+            print()
+            print("Trying to call ollama? Try `litellm --model ollama/llama2 --api_base http://localhost:11434`")
+            print()
+    if isinstance(
+        e,
+        (
+            openai.error.APIError,
+            openai.error.TryAgain,
+            openai.error.Timeout,
+            openai.error.ServiceUnavailableError,
+        ),
+    ):
+        raise RetryConstantError from e
+    elif isinstance(e, openai.error.RateLimitError):
+        raise RetryExpoError from e
+    elif isinstance(
+        e,
+        (
+            openai.error.APIConnectionError,
+            openai.error.InvalidRequestError,
+            openai.error.AuthenticationError,
+            openai.error.PermissionError,
+            openai.error.InvalidAPIType,
+            openai.error.SignatureVerificationError,
+        ),
+    ):
+        raise e
+    else:
+        raise UnknownLLMError from e
+
+
+@backoff.on_exception(
+    wait_gen=backoff.constant,
+    exception=RetryConstantError,
+    max_tries=3,
+    interval=3,
+)
+@backoff.on_exception(
+    wait_gen=backoff.expo,
+    exception=RetryExpoError,
+    jitter=backoff.full_jitter,
+    max_value=100,
+    factor=1.5,
+)
+
+def litellm_completion(data: Dict,
+                       type: str,
+                       user_model: Optional[str],
+                       user_temperature: Optional[str],
+                       user_max_tokens: Optional[int],
+                       user_api_base: Optional[str],
+                       user_headers: Optional[dict],
+                       user_debug: bool) -> litellm.ModelResponse:
+    try:
+        global debug
+        debug = user_debug
+        if user_model:
+            data["model"] = user_model
+        # override with user settings
+        if user_temperature:
+            data["temperature"] = user_temperature
+        if user_max_tokens:
+            data["max_tokens"] = user_max_tokens
+        if user_api_base:
+            data["api_base"] = user_api_base
+        if user_headers:
+            data["headers"] = user_headers
+        if type == "completion":
+            response = litellm.text_completion(**data)
+        elif type == "chat_completion":
+            response = litellm.completion(**data)
+        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
+            return StreamingResponse(data_generator(response), media_type='text/event-stream')
+        print_verbose(f"response: {response}")
+        return response
+    except Exception as e:
+        print(e)
+        handle_llm_exception(e=e, user_api_base=user_api_base)
+        return {"message": "An error occurred"}, 500
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 854f27aa36..90a9679213 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -23,6 +23,10 @@ except ImportError:
     import appdirs
     import tomli_w
 
+try:
+    from .llm import litellm_completion
+except ImportError as e:
+    from llm import litellm_completion
 import random
 
 list_of_messages = [
@@ -305,14 +309,6 @@ def deploy_proxy(model, api_base, debug, temperature, max_tokens, telemetry, dep
 
     return url
 
-
-# for streaming
-def data_generator(response):
-    print_verbose("inside generator")
-    for chunk in response:
-        print_verbose(f"returned chunk: {chunk}")
-        yield f"data: {json.dumps(chunk)}\n\n"
-
 def track_cost_callback(
     kwargs, # kwargs to completion
     completion_response, # response from completion
@@ -433,49 +429,6 @@ litellm.input_callback = [logger]
 litellm.success_callback = [logger]
 litellm.failure_callback = [logger]
 
-def litellm_completion(data, type):
-    try:
-        if user_model:
-            data["model"] = user_model
-        # override with user settings
-        if user_temperature:
-            data["temperature"] = user_temperature
-        if user_max_tokens:
-            data["max_tokens"] = user_max_tokens
-        if user_api_base:
-            data["api_base"] = user_api_base
-        if user_headers:
-            data["headers"] = user_headers
-        if type == "completion":
-            response = litellm.text_completion(**data)
-        elif type == "chat_completion":
-            response = litellm.completion(**data)
-        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
-            return StreamingResponse(data_generator(response), media_type='text/event-stream')
-        print_verbose(f"response: {response}")
-        return response
-    except Exception as e:
-        traceback.print_exc()
-        if "Invalid response object from API" in str(e):
-            completion_call_details = {}
-            if user_model:
-                completion_call_details["model"] = user_model
-            else:
-                completion_call_details["model"] = data['model']
-
-            if user_api_base:
-                completion_call_details["api_base"] = user_api_base
-            else:
-                completion_call_details["api_base"] = None
-            print(f"\033[1;31mLiteLLM.Exception: Invalid API Call. Call details: Model: \033[1;37m{completion_call_details['model']}\033[1;31m; LLM Provider: \033[1;37m{e.llm_provider}\033[1;31m; Custom API Base - \033[1;37m{completion_call_details['api_base']}\033[1;31m\033[0m")
-            if completion_call_details["api_base"] == "http://localhost:11434":
-                print()
-                print("Trying to call ollama? Try `litellm --model ollama/llama2 --api_base http://localhost:11434`")
-                print()
-        else:
-            print(f"\033[1;31mLiteLLM.Exception: {str(e)}\033[0m")
-        return {"message": "An error occurred"}, 500
-
 #### API ENDPOINTS ####
 @router.get("/models") # if project requires model list
 def model_list():
@@ -494,12 +447,12 @@ def model_list():
 
 @router.post("/completions")
 async def completion(request: Request):
     data = await request.json()
-    return litellm_completion(data=data, type="completion")
+    return litellm_completion(data=data, type="completion", user_model=user_model, user_temperature=user_temperature, user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)
 
 @router.post("/chat/completions")
 async def chat_completion(request: Request):
     data = await request.json()
-    response = litellm_completion(data, type="chat_completion")
+    response = litellm_completion(data, type="chat_completion", user_model=user_model, user_temperature=user_temperature, user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)
     return response
 
diff --git a/litellm/utils.py b/litellm/utils.py
index 64499381bf..478eeda36a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3092,14 +3092,31 @@ def exception_type(
                     raise original_exception
                 raise original_exception
             elif custom_llm_provider == "ollama":
-                error_str = original_exception.get("error", "")
+                if isinstance(original_exception, dict):
+                    error_str = original_exception.get("error", "")
+                else:
+                    error_str = str(original_exception)
                 if "no such file or directory" in error_str:
                     exception_mapping_worked = True
                     raise InvalidRequestError(
-                        message=f"Ollama Exception Invalid Model/Model not loaded - {original_exception}",
+                        message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}",
                         model=model,
                         llm_provider="ollama"
                     )
+                elif "Failed to establish a new connection" in error_str:
+                    exception_mapping_worked = True
+                    raise ServiceUnavailableError(
+                        message=f"OllamaException: {original_exception}",
+                        llm_provider="ollama",
+                        model=model
+                    )
+                elif "Invalid response object from API" in error_str:
+                    exception_mapping_worked = True
+                    raise InvalidRequestError(
+                        message=f"OllamaException: {original_exception}",
+                        llm_provider="ollama",
+                        model=model
+                    )
             elif custom_llm_provider == "vllm":
                 if hasattr(original_exception, "status_code"):
                     if original_exception.status_code == 0:
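Usage sketch: the extracted `litellm_completion` in litellm/proxy/llm.py can be exercised directly with the same keyword arguments the updated /completions and /chat/completions routes now forward. This is only an illustration; the import path assumes litellm/proxy is importable as a package, and the model name and api_base are taken from the Ollama hint printed by the new error handler rather than being requirements of this change.

    # Hypothetical direct call; assumes litellm is installed and an Ollama
    # server is reachable at http://localhost:11434.
    from litellm.proxy.llm import litellm_completion

    response = litellm_completion(
        data={
            "model": "ollama/llama2",
            "messages": [{"role": "user", "content": "Hello from the proxy"}],
        },
        type="chat_completion",
        user_model=None,            # no CLI override; keep data["model"] as-is
        user_temperature=None,
        user_max_tokens=None,
        user_api_base="http://localhost:11434",
        user_headers=None,
        user_debug=True,
    )
    print(response)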
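For reference, a toy sketch (separate from the code in this diff) of how the stacked backoff decorators on `litellm_completion` are meant to behave: exceptions re-raised by `handle_llm_exception` as `RetryConstantError` are retried a fixed number of times at a constant interval, `RetryExpoError` (rate limits) backs off exponentially with jitter, and anything else propagates immediately. The `interval` and `max_tries` values below are shortened for the demo and are not the settings used in llm.py.

    import backoff

    class RetryConstantError(Exception):
        pass

    class RetryExpoError(Exception):
        pass

    attempts = {"count": 0}

    @backoff.on_exception(wait_gen=backoff.constant, exception=RetryConstantError, max_tries=3, interval=1)
    @backoff.on_exception(wait_gen=backoff.expo, exception=RetryExpoError, jitter=backoff.full_jitter, max_tries=5)
    def flaky_call():
        # Fail twice with a "transient" error, then succeed; the constant-wait
        # decorator absorbs the retries, so the caller only sees the final result.
        attempts["count"] += 1
        if attempts["count"] < 3:
            raise RetryConstantError("upstream temporarily unavailable")
        return "ok"

    print(flaky_call(), "after", attempts["count"], "attempts")  # -> ok after 3 attempts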