diff --git a/enterprise/enterprise_callbacks/generic_api_callback.py b/enterprise/enterprise_callbacks/generic_api_callback.py index eddaa06716..cfeea7d696 100644 --- a/enterprise/enterprise_callbacks/generic_api_callback.py +++ b/enterprise/enterprise_callbacks/generic_api_callback.py @@ -3,7 +3,6 @@ #### What this does #### # On success, logs events to Promptlayer import dotenv, os -import requests from litellm.proxy._types import UserAPIKeyAuth from litellm.caching.caching import DualCache @@ -17,7 +16,6 @@ import traceback # On success + failure, log events to Supabase import dotenv, os -import requests import traceback import datetime, subprocess, sys import litellm, uuid @@ -116,7 +114,9 @@ class GenericAPILogger: print_verbose(f"\nGeneric Logger - Logging payload = {data}") # make request to endpoint with payload - response = requests.post(self.endpoint, json=data, headers=self.headers) + response = litellm.module_level_client.post( + self.endpoint, json=data, headers=self.headers + ) response_status = response.status_code response_text = response.text diff --git a/litellm/budget_manager.py b/litellm/budget_manager.py index 6be2d0418a..a17edcdbe8 100644 --- a/litellm/budget_manager.py +++ b/litellm/budget_manager.py @@ -13,8 +13,6 @@ import threading import time from typing import Literal, Optional, Union -import requests # type: ignore - import litellm from litellm.utils import ModelResponse @@ -58,7 +56,9 @@ class BudgetManager: # Load the user_dict from hosted db url = self.api_base + "/get_budget" data = {"project_name": self.project_name} - response = requests.post(url, headers=self.headers, json=data) + response = litellm.module_level_client.post( + url, headers=self.headers, json=data + ) response = response.json() if response["status"] == "error": self.user_dict = ( @@ -215,6 +215,8 @@ class BudgetManager: elif self.client_type == "hosted": url = self.api_base + "/set_budget" data = {"project_name": self.project_name, "user_dict": self.user_dict} - response = requests.post(url, headers=self.headers, json=data) + response = litellm.module_level_client.post( + url, headers=self.headers, json=data + ) response = response.json() return response diff --git a/litellm/integrations/argilla.py b/litellm/integrations/argilla.py index d4719591f0..1ec7924b6f 100644 --- a/litellm/integrations/argilla.py +++ b/litellm/integrations/argilla.py @@ -15,19 +15,20 @@ from typing import Any, Dict, List, Optional, TypedDict, Union import dotenv # type: ignore import httpx -import requests # type: ignore from pydantic import BaseModel # type: ignore import litellm from litellm._logging import verbose_logger from litellm.integrations.custom_batch_logger import CustomBatchLogger from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.prompt_templates.common_utils import ( + get_content_from_model_response, +) from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, get_async_httpx_client, httpxSpecialProvider, ) -from litellm.litellm_core_utils.prompt_templates.common_utils import get_content_from_model_response from litellm.types.integrations.argilla import ( SUPPORTED_PAYLOAD_FIELDS, ArgillaCredentialsObject, @@ -223,7 +224,7 @@ class ArgillaLogger(CustomBatchLogger): headers = {"X-Argilla-Api-Key": argilla_api_key} try: - response = requests.post( + response = litellm.module_level_client.post( url=url, json=self.log_queue, headers=headers, diff --git a/litellm/integrations/athina.py b/litellm/integrations/athina.py index b6f5447d82..f669b7c9ac 
100644 --- a/litellm/integrations/athina.py +++ b/litellm/integrations/athina.py @@ -1,5 +1,7 @@ import datetime +import litellm + class AthinaLogger: def __init__(self): @@ -27,8 +29,6 @@ class AthinaLogger: import json import traceback - import requests # type: ignore - try: is_stream = kwargs.get("stream", False) if is_stream: @@ -81,7 +81,7 @@ class AthinaLogger: if key in metadata: data[key] = metadata[key] - response = requests.post( + response = litellm.module_level_client.post( self.athina_logging_url, headers=self.headers, data=json.dumps(data, default=str), diff --git a/litellm/integrations/dynamodb.py b/litellm/integrations/dynamodb.py index b5882c3254..5257020b44 100644 --- a/litellm/integrations/dynamodb.py +++ b/litellm/integrations/dynamodb.py @@ -8,7 +8,6 @@ import uuid from typing import Any import dotenv -import requests # type: ignore import litellm diff --git a/litellm/integrations/greenscale.py b/litellm/integrations/greenscale.py index a27acae427..430c3d0abf 100644 --- a/litellm/integrations/greenscale.py +++ b/litellm/integrations/greenscale.py @@ -2,7 +2,7 @@ import json import traceback from datetime import datetime, timezone -import requests # type: ignore +import litellm class GreenscaleLogger: @@ -54,7 +54,7 @@ class GreenscaleLogger: if self.greenscale_logging_url is None: raise Exception("Greenscale Logger Error - No logging URL found") - response = requests.post( + response = litellm.module_level_client.post( self.greenscale_logging_url, headers=self.headers, data=json.dumps(data, default=str), diff --git a/litellm/integrations/helicone.py b/litellm/integrations/helicone.py index 3291e94366..013bf3c6cd 100644 --- a/litellm/integrations/helicone.py +++ b/litellm/integrations/helicone.py @@ -4,7 +4,6 @@ import os import traceback import dotenv -import requests # type: ignore import litellm from litellm._logging import verbose_logger @@ -179,7 +178,7 @@ class HeliconeLogger: }, }, # {"seconds": .., "milliseconds": ..} } - response = requests.post(url, headers=headers, json=data) + response = litellm.module_level_client.post(url, headers=headers, json=data) if response.status_code == 200: print_verbose("Helicone Logging - Success!") else: diff --git a/litellm/integrations/langsmith.py b/litellm/integrations/langsmith.py index fbcc4d081b..4c5ec17fc8 100644 --- a/litellm/integrations/langsmith.py +++ b/litellm/integrations/langsmith.py @@ -12,7 +12,6 @@ from typing import Any, Dict, List, Optional, TypedDict, Union import dotenv # type: ignore import httpx -import requests # type: ignore from pydantic import BaseModel # type: ignore import litellm @@ -481,7 +480,7 @@ class LangsmithLogger(CustomBatchLogger): langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"] url = f"{langsmith_api_base}/runs/{run_id}" - response = requests.get( + response = litellm.module_level_client.get( url=url, headers={"x-api-key": langsmith_api_key}, ) diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index b87f245240..73198d0ba7 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -9,9 +9,6 @@ import uuid from datetime import date, datetime, timedelta from typing import Optional, TypedDict, Union -import dotenv -import requests # type: ignore - import litellm from litellm._logging import print_verbose, verbose_logger from litellm.integrations.custom_logger import CustomLogger diff --git a/litellm/integrations/prometheus_services.py b/litellm/integrations/prometheus_services.py index 
df94ffcd84..407a8e698b 100644 --- a/litellm/integrations/prometheus_services.py +++ b/litellm/integrations/prometheus_services.py @@ -11,9 +11,6 @@ import traceback import uuid from typing import List, Optional, Union -import dotenv -import requests # type: ignore - import litellm from litellm._logging import print_verbose, verbose_logger from litellm.types.integrations.prometheus import LATENCY_BUCKETS diff --git a/litellm/integrations/prompt_layer.py b/litellm/integrations/prompt_layer.py index 8d62b50b05..190b995fa4 100644 --- a/litellm/integrations/prompt_layer.py +++ b/litellm/integrations/prompt_layer.py @@ -3,10 +3,10 @@ import os import traceback -import dotenv -import requests # type: ignore from pydantic import BaseModel +import litellm + class PromptLayerLogger: # Class variables or attributes @@ -47,7 +47,7 @@ class PromptLayerLogger: if isinstance(response_obj, BaseModel): response_obj = response_obj.model_dump() - request_response = requests.post( + request_response = litellm.module_level_client.post( "https://api.promptlayer.com/rest/track-request", json={ "function_name": "openai.ChatCompletion.create", @@ -74,7 +74,7 @@ class PromptLayerLogger: if "request_id" in response_json: if metadata: - response = requests.post( + response = litellm.module_level_client.post( "https://api.promptlayer.com/rest/track-metadata", json={ "request_id": response_json["request_id"], diff --git a/litellm/integrations/supabase.py b/litellm/integrations/supabase.py index ed094e7d70..7f64e0ff12 100644 --- a/litellm/integrations/supabase.py +++ b/litellm/integrations/supabase.py @@ -8,7 +8,6 @@ import sys import traceback import dotenv -import requests # type: ignore import litellm diff --git a/litellm/integrations/weights_biases.py b/litellm/integrations/weights_biases.py index f2384fdf4b..f835eb93e7 100644 --- a/litellm/integrations/weights_biases.py +++ b/litellm/integrations/weights_biases.py @@ -177,8 +177,6 @@ import os import traceback from datetime import datetime -import requests - class WeightsBiasesLogger: # Class variables or attributes diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 619178ca04..b8d428100e 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -570,40 +570,6 @@ class CustomStreamWrapper: ) return "" - def handle_ollama_stream(self, chunk): - try: - if isinstance(chunk, dict): - json_chunk = chunk - else: - json_chunk = json.loads(chunk) - if "error" in json_chunk: - raise Exception(f"Ollama Error - {json_chunk}") - - text = "" - is_finished = False - finish_reason = None - if json_chunk["done"] is True: - text = "" - is_finished = True - finish_reason = "stop" - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } - elif json_chunk["response"]: - print_verbose(f"delta content: {json_chunk}") - text = json_chunk["response"] - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } - else: - raise Exception(f"Ollama Error - {json_chunk}") - except Exception as e: - raise e - def handle_ollama_chat_stream(self, chunk): # for ollama_chat/ provider try: @@ -1111,12 +1077,6 @@ class CustomStreamWrapper: new_chunk = self.completion_stream[:chunk_size] completion_obj["content"] = new_chunk self.completion_stream = self.completion_stream[chunk_size:] - elif self.custom_llm_provider == "ollama": - response_obj = self.handle_ollama_stream(chunk) - completion_obj["content"] = 
response_obj["text"] - print_verbose(f"completion obj content: {completion_obj['content']}") - if response_obj["is_finished"]: - self.received_finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider == "ollama_chat": response_obj = self.handle_ollama_chat_stream(chunk) completion_obj["content"] = response_obj["text"] diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index 275e3b868d..f7df3b01c6 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -13,7 +13,6 @@ from functools import partial from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import httpx # type: ignore -import requests # type: ignore from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice import litellm diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 30f87d5456..39582d1314 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -16,7 +16,6 @@ from typing import ( ) import httpx -import requests import litellm from litellm.constants import RESPONSE_FORMAT_TOOL_NAME diff --git a/litellm/llms/azure/completion/handler.py b/litellm/llms/azure/completion/handler.py index 889d40d16f..16211926a4 100644 --- a/litellm/llms/azure/completion/handler.py +++ b/litellm/llms/azure/completion/handler.py @@ -4,11 +4,14 @@ import uuid from typing import Any, Callable, Optional, Union import httpx -import requests from openai import AsyncAzureOpenAI, AzureOpenAI import litellm from litellm import OpenAIConfig +from litellm.litellm_core_utils.prompt_templates.factory import ( + custom_prompt, + prompt_factory, +) from litellm.utils import ( Choices, CustomStreamWrapper, @@ -22,7 +25,6 @@ from litellm.utils import ( from ...base import BaseLLM from ...openai.completion.handler import OpenAITextCompletion from ...openai.completion.transformation import OpenAITextCompletionConfig -from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory from ..common_utils import AzureOpenAIError openai_text_completion_config = OpenAITextCompletionConfig() diff --git a/litellm/llms/baseten.py b/litellm/llms/baseten.py index ce0a5599bd..d0d42b6d1b 100644 --- a/litellm/llms/baseten.py +++ b/litellm/llms/baseten.py @@ -4,9 +4,8 @@ import time from enum import Enum from typing import Callable -import requests # type: ignore - -from litellm.utils import ModelResponse, Usage +import litellm +from litellm.types.utils import ModelResponse, Usage class BasetenError(Exception): @@ -71,7 +70,7 @@ def completion( additional_args={"complete_input_dict": data}, ) ## COMPLETION CALL - response = requests.post( + response = litellm.module_level_client.post( completion_url_fragment_1 + model + completion_url_fragment_2, headers=headers, data=json.dumps(data), diff --git a/litellm/llms/bedrock/base_aws_llm.py b/litellm/llms/bedrock/base_aws_llm.py index 9df9204b6c..2f225b7b14 100644 --- a/litellm/llms/bedrock/base_aws_llm.py +++ b/litellm/llms/bedrock/base_aws_llm.py @@ -10,8 +10,6 @@ from litellm._logging import verbose_logger from litellm.caching.caching import DualCache, InMemoryCache from litellm.secret_managers.main import get_secret, get_secret_str -from litellm.llms.base import BaseLLM - if TYPE_CHECKING: from botocore.credentials import Credentials else: @@ -37,7 +35,7 @@ class AwsAuthError(Exception): ) # Call the base class constructor with the 
parameters it needs -class BaseAWSLLM(BaseLLM): +class BaseAWSLLM: def __init__(self) -> None: self.iam_cache = DualCache() super().__init__() diff --git a/litellm/llms/bedrock/chat/invoke_handler.py b/litellm/llms/bedrock/chat/invoke_handler.py index 963e3fca59..6348a2bfe9 100644 --- a/litellm/llms/bedrock/chat/invoke_handler.py +++ b/litellm/llms/bedrock/chat/invoke_handler.py @@ -25,7 +25,6 @@ from typing import ( ) import httpx # type: ignore -import requests # type: ignore import litellm from litellm import verbose_logger @@ -316,7 +315,7 @@ class BedrockLLM(BaseAWSLLM): def process_response( # noqa: PLR0915 self, model: str, - response: Union[requests.Response, httpx.Response], + response: httpx.Response, model_response: ModelResponse, stream: bool, logging_obj: Logging, @@ -1041,9 +1040,6 @@ class BedrockLLM(BaseAWSLLM): ) return streaming_response - def embedding(self, *args, **kwargs): - return super().embedding(*args, **kwargs) - def get_response_stream_shape(): global _response_stream_shape_cache diff --git a/litellm/llms/codestral/completion/handler.py b/litellm/llms/codestral/completion/handler.py index e04da501bf..0a9e86654e 100644 --- a/litellm/llms/codestral/completion/handler.py +++ b/litellm/llms/codestral/completion/handler.py @@ -12,7 +12,6 @@ from functools import partial from typing import Callable, List, Literal, Optional, Union import httpx # type: ignore -import requests # type: ignore import litellm from litellm import verbose_logger @@ -22,7 +21,6 @@ from litellm.litellm_core_utils.prompt_templates.factory import ( custom_prompt, prompt_factory, ) -from litellm.llms.base import BaseLLM from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, get_async_httpx_client, @@ -95,7 +93,7 @@ async def make_call( return completion_stream -class CodestralTextCompletion(BaseLLM): +class CodestralTextCompletion: def __init__(self) -> None: super().__init__() @@ -139,7 +137,7 @@ class CodestralTextCompletion(BaseLLM): def process_text_completion_response( self, model: str, - response: Union[requests.Response, httpx.Response], + response: httpx.Response, model_response: TextCompletionResponse, stream: bool, logging_obj: LiteLLMLogging, @@ -317,7 +315,7 @@ class CodestralTextCompletion(BaseLLM): ### SYNC STREAMING if stream is True: - response = requests.post( + response = litellm.module_level_client.post( completion_url, headers=headers, data=json.dumps(data), @@ -333,7 +331,7 @@ class CodestralTextCompletion(BaseLLM): ### SYNC COMPLETION else: - response = requests.post( + response = litellm.module_level_client.post( url=completion_url, headers=headers, data=json.dumps(data), diff --git a/litellm/llms/cohere/embed/handler.py b/litellm/llms/cohere/embed/handler.py index 5258df2b7f..2a9d7512e3 100644 --- a/litellm/llms/cohere/embed/handler.py +++ b/litellm/llms/cohere/embed/handler.py @@ -6,8 +6,7 @@ import types from enum import Enum from typing import Any, Callable, Optional, Union -import httpx # type: ignore -import requests # type: ignore +import httpx import litellm from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index aa91662918..6d37828498 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -491,7 +491,7 @@ class HTTPHandler: self, url: str, data: Optional[Union[dict, str]] = None, - json: Optional[Union[dict, str]] = None, + json: Optional[Union[dict, str, List]] = None, 
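# --- reviewer note, not part of the patch -----------------------------------
# Minimal sketch of the call pattern this PR standardizes on: every former
# `requests.post(...)` now goes through the shared `litellm.module_level_client`
# (an HTTPHandler built on httpx). The `json` parameter above is widened to
# accept a List because batch loggers such as ArgillaLogger send their whole
# log_queue (a list of dicts) as the request body. The URL, payload, and token
# below are made up for illustration only.
import litellm

batch_payload = [{"event": "first"}, {"event": "second"}]  # hypothetical batch body
response = litellm.module_level_client.post(
    "https://example.com/bulk-logs",  # hypothetical logging endpoint
    json=batch_payload,               # a list is now a valid `json` value
    headers={"Authorization": "Bearer sk-1234"},
)
print(response.status_code, response.text)
# -----------------------------------------------------------------------------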
params: Optional[dict] = None, headers: Optional[dict] = None, stream: bool = False, diff --git a/litellm/llms/custom_llm.py b/litellm/llms/custom_llm.py index c19ba621b9..90f7875e66 100644 --- a/litellm/llms/custom_llm.py +++ b/litellm/llms/custom_llm.py @@ -29,8 +29,7 @@ from typing import ( Union, ) -import httpx # type: ignore -import requests # type: ignore +import httpx import litellm from litellm.litellm_core_utils.core_helpers import map_finish_reason @@ -46,6 +45,7 @@ from litellm.utils import ( from .base import BaseLLM + class CustomLLMError(Exception): # use this for all your exceptions def __init__( self, diff --git a/litellm/llms/deprecated_providers/aleph_alpha.py b/litellm/llms/deprecated_providers/aleph_alpha.py index bdea58e428..90da85d3b0 100644 --- a/litellm/llms/deprecated_providers/aleph_alpha.py +++ b/litellm/llms/deprecated_providers/aleph_alpha.py @@ -6,7 +6,6 @@ from enum import Enum from typing import Callable, Optional import httpx # type: ignore -import requests # type: ignore import litellm from litellm.utils import Choices, Message, ModelResponse, Usage @@ -240,7 +239,7 @@ def completion( additional_args={"complete_input_dict": data}, ) ## COMPLETION CALL - response = requests.post( + response = litellm.module_level_client.post( completion_url, headers=headers, data=json.dumps(data), diff --git a/litellm/llms/huggingface/chat/handler.py b/litellm/llms/huggingface/chat/handler.py index eadb62fb30..b1b7a6c2d9 100644 --- a/litellm/llms/huggingface/chat/handler.py +++ b/litellm/llms/huggingface/chat/handler.py @@ -20,7 +20,6 @@ from typing import ( ) import httpx -import requests import litellm from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj diff --git a/litellm/llms/ollama/common_utils.py b/litellm/llms/ollama/common_utils.py index 90cdad53fb..38f82ee7dc 100644 --- a/litellm/llms/ollama/common_utils.py +++ b/litellm/llms/ollama/common_utils.py @@ -10,3 +10,36 @@ class OllamaError(BaseLLMException): self, status_code: int, message: str, headers: Union[dict, httpx.Headers] ): super().__init__(status_code=status_code, message=message, headers=headers) + + +def _convert_image(image): + """ + Convert image to base64 encoded image if not already in base64 format + + If image is already in base64 format AND is a jpeg/png, return it + + If image is not JPEG/PNG, convert it to JPEG base64 format + """ + import base64 + import io + + try: + from PIL import Image + except Exception: + raise Exception( + "ollama image conversion failed please run `pip install Pillow`" + ) + + orig = image + if image.startswith("data:"): + image = image.split(",")[-1] + try: + image_data = Image.open(io.BytesIO(base64.b64decode(image))) + if image_data.format in ["JPEG", "PNG"]: + return image + except Exception: + return orig + jpeg_image = io.BytesIO() + image_data.convert("RGB").save(jpeg_image, "JPEG") + jpeg_image.seek(0) + return base64.b64encode(jpeg_image.getvalue()).decode("utf-8") diff --git a/litellm/llms/ollama/completion/handler.py b/litellm/llms/ollama/completion/handler.py index d50e7d5e64..8b6f26995d 100644 --- a/litellm/llms/ollama/completion/handler.py +++ b/litellm/llms/ollama/completion/handler.py @@ -1,3 +1,9 @@ +""" +Ollama /chat/completion calls handled in llm_http_handler.py + +[TODO]: migrate embeddings to a base handler as well. 
+""" + import asyncio import json import time @@ -8,10 +14,6 @@ from copy import deepcopy from itertools import chain from typing import Any, Dict, List, Optional -import aiohttp -import httpx # type: ignore -import requests # type: ignore - import litellm from litellm import verbose_logger from litellm.litellm_core_utils.prompt_templates.factory import ( @@ -31,370 +33,8 @@ from litellm.types.utils import ( from ..common_utils import OllamaError from .transformation import OllamaConfig - # ollama wants plain base64 jpeg/png files as images. strip any leading dataURI # and convert to jpeg if necessary. -def _convert_image(image): - import base64 - import io - - try: - from PIL import Image - except Exception: - raise Exception( - "ollama image conversion failed please run `pip install Pillow`" - ) - - orig = image - if image.startswith("data:"): - image = image.split(",")[-1] - try: - image_data = Image.open(io.BytesIO(base64.b64decode(image))) - if image_data.format in ["JPEG", "PNG"]: - return image - except Exception: - return orig - jpeg_image = io.BytesIO() - image_data.convert("RGB").save(jpeg_image, "JPEG") - jpeg_image.seek(0) - return base64.b64encode(jpeg_image.getvalue()).decode("utf-8") - - -# ollama implementation -def get_ollama_response( - model_response: ModelResponse, - model: str, - prompt: str, - optional_params: dict, - logging_obj: Any, - encoding: Any, - acompletion: bool = False, - api_base="http://localhost:11434", -): - if api_base.endswith("/api/generate"): - url = api_base - else: - url = f"{api_base}/api/generate" - - ## Load Config - config = litellm.OllamaConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - stream = optional_params.pop("stream", False) - format = optional_params.pop("format", None) - images = optional_params.pop("images", None) - data = { - "model": model, - "prompt": prompt, - "options": optional_params, - "stream": stream, - } - if format is not None: - data["format"] = format - if images is not None: - data["images"] = [_convert_image(image) for image in images] - - ## LOGGING - logging_obj.pre_call( - input=None, - api_key=None, - additional_args={ - "api_base": url, - "complete_input_dict": data, - "headers": {}, - "acompletion": acompletion, - }, - ) - if acompletion is True: - if stream is True: - response = ollama_async_streaming( - url=url, - data=data, - model_response=model_response, - encoding=encoding, - logging_obj=logging_obj, - ) - else: - response = ollama_acompletion( - url=url, - data=data, - model_response=model_response, - encoding=encoding, - logging_obj=logging_obj, - ) - return response - elif stream is True: - return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj) - - response = requests.post( - url=f"{url}", json={**data, "stream": stream}, timeout=litellm.request_timeout - ) - if response.status_code != 200: - raise OllamaError( - status_code=response.status_code, - message=response.text, - headers=dict(response.headers), - ) - - ## LOGGING - logging_obj.post_call( - input=prompt, - api_key="", - original_response=response.text, - additional_args={ - "headers": None, - "api_base": api_base, - }, - ) - - response_json = response.json() - - ## RESPONSE OBJECT - model_response.choices[0].finish_reason = "stop" - if data.get("format", "") == "json": - function_call = json.loads(response_json["response"]) - message = litellm.Message( - 
content=None, - tool_calls=[ - { - "id": f"call_{str(uuid.uuid4())}", - "function": { - "name": function_call["name"], - "arguments": json.dumps(function_call["arguments"]), - }, - "type": "function", - } - ], - ) - model_response.choices[0].message = message # type: ignore - model_response.choices[0].finish_reason = "tool_calls" - else: - model_response.choices[0].message.content = response_json["response"] # type: ignore - model_response.created = int(time.time()) - model_response.model = "ollama/" + model - prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore - completion_tokens = response_json.get( - "eval_count", len(response_json.get("message", dict()).get("content", "")) - ) - setattr( - model_response, - "usage", - litellm.Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - return model_response - - -def ollama_completion_stream(url, data, logging_obj): - with httpx.stream( - url=url, json=data, method="POST", timeout=litellm.request_timeout - ) as response: - try: - if response.status_code != 200: - raise OllamaError( - status_code=response.status_code, - message=str(response.read()), - headers=response.headers, - ) - - streamwrapper = litellm.CustomStreamWrapper( - completion_stream=response.iter_lines(), - model=data["model"], - custom_llm_provider="ollama", - logging_obj=logging_obj, - ) - # If format is JSON, this was a function call - # Gather all chunks and return the function call as one delta to simplify parsing - if data.get("format", "") == "json": - first_chunk = next(streamwrapper) - content_chunks = [] - for chunk in chain([first_chunk], streamwrapper): - content_chunk = chunk.choices[0] - if ( - isinstance(content_chunk, StreamingChoices) - and hasattr(content_chunk, "delta") - and hasattr(content_chunk.delta, "content") - and content_chunk.delta.content is not None - ): - content_chunks.append(content_chunk.delta.content) - response_content = "".join(content_chunks) - - function_call = json.loads(response_content) - delta = litellm.utils.Delta( - content=None, - tool_calls=[ - { - "id": f"call_{str(uuid.uuid4())}", - "function": { - "name": function_call["name"], - "arguments": json.dumps(function_call["arguments"]), - }, - "type": "function", - } - ], - ) - model_response = first_chunk - model_response.choices[0].delta = delta # type: ignore - model_response.choices[0].finish_reason = "tool_calls" - yield model_response - else: - for transformed_chunk in streamwrapper: - yield transformed_chunk - except Exception as e: - raise e - - -async def ollama_async_streaming(url, data, model_response, encoding, logging_obj): - try: - _async_http_client = get_async_httpx_client( - llm_provider=litellm.LlmProviders.OLLAMA - ) - client = _async_http_client.client - async with client.stream( - url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout - ) as response: - if response.status_code != 200: - raise OllamaError( - status_code=response.status_code, - message=str(await response.aread()), - headers=dict(response.headers), - ) - - streamwrapper = litellm.CustomStreamWrapper( - completion_stream=response.aiter_lines(), - model=data["model"], - custom_llm_provider="ollama", - logging_obj=logging_obj, - ) - - # If format is JSON, this was a function call - # Gather all chunks and return the function call as one delta to simplify parsing - if data.get("format", "") == "json": - first_chunk = await 
anext(streamwrapper) # noqa F821 - chunk_choice = first_chunk.choices[0] - if ( - isinstance(chunk_choice, StreamingChoices) - and hasattr(chunk_choice, "delta") - and hasattr(chunk_choice.delta, "content") - ): - first_chunk_content = chunk_choice.delta.content or "" - else: - first_chunk_content = "" - - content_chunks = [] - async for chunk in streamwrapper: - chunk_choice = chunk.choices[0] - if ( - isinstance(chunk_choice, StreamingChoices) - and hasattr(chunk_choice, "delta") - and hasattr(chunk_choice.delta, "content") - ): - content_chunks.append(chunk_choice.delta.content) - response_content = first_chunk_content + "".join(content_chunks) - function_call = json.loads(response_content) - delta = litellm.utils.Delta( - content=None, - tool_calls=[ - { - "id": f"call_{str(uuid.uuid4())}", - "function": { - "name": function_call["name"], - "arguments": json.dumps(function_call["arguments"]), - }, - "type": "function", - } - ], - ) - model_response = first_chunk - model_response.choices[0].delta = delta # type: ignore - model_response.choices[0].finish_reason = "tool_calls" - yield model_response - else: - async for transformed_chunk in streamwrapper: - yield transformed_chunk - except Exception as e: - raise e # don't use verbose_logger.exception, if exception is raised - - -async def ollama_acompletion( - url, data, model_response: litellm.ModelResponse, encoding, logging_obj -): - data["stream"] = False - try: - timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes - async with aiohttp.ClientSession(timeout=timeout) as session: - resp = await session.post(url, json=data) - - if resp.status != 200: - text = await resp.text() - raise OllamaError( - status_code=resp.status, - message=text, - headers=dict(resp.headers), - ) - - ## LOGGING - logging_obj.post_call( - input=data["prompt"], - api_key="", - original_response=resp.text, - additional_args={ - "headers": None, - "api_base": url, - }, - ) - - response_json = await resp.json() - ## RESPONSE OBJECT - model_response.choices[0].finish_reason = "stop" - if data.get("format", "") == "json": - function_call = json.loads(response_json["response"]) - message = litellm.Message( - content=None, - tool_calls=[ - { - "id": f"call_{str(uuid.uuid4())}", - "function": { - "name": function_call.get( - "name", function_call.get("function", None) - ), - "arguments": json.dumps(function_call["arguments"]), - }, - "type": "function", - } - ], - ) - model_response.choices[0].message = message # type: ignore - model_response.choices[0].finish_reason = "tool_calls" - else: - model_response.choices[0].message.content = response_json["response"] # type: ignore - model_response.created = int(time.time()) - model_response.model = "ollama/" + data["model"] - prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore - completion_tokens = response_json.get( - "eval_count", - len(response_json.get("message", dict()).get("content", "")), - ) - setattr( - model_response, - "usage", - litellm.Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - return model_response - except Exception as e: - raise e # don't use verbose_logger.exception, if exception is raised - async def ollama_aembeddings( api_base: str, @@ -432,39 +72,18 @@ async def ollama_aembeddings( total_input_tokens = 0 output_data = [] - timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes - async with 
aiohttp.ClientSession(timeout=timeout) as session: - ## LOGGING - logging_obj.pre_call( - input=None, - api_key=None, - additional_args={ - "api_base": url, - "complete_input_dict": data, - "headers": {}, - }, - ) + response = await litellm.module_level_aclient.post(url=url, json=data) - response = await session.post(url, json=data) + response_json = await response.json() - if response.status != 200: - text = await response.text() - raise OllamaError( - status_code=response.status, - message=text, - headers=dict(response.headers), - ) + embeddings: List[List[float]] = response_json["embeddings"] + for idx, emb in enumerate(embeddings): + output_data.append({"object": "embedding", "index": idx, "embedding": emb}) - response_json = await response.json() - - embeddings: List[List[float]] = response_json["embeddings"] - for idx, emb in enumerate(embeddings): - output_data.append({"object": "embedding", "index": idx, "embedding": emb}) - - input_tokens = response_json.get("prompt_eval_count") or len( - encoding.encode("".join(prompt for prompt in prompts)) - ) - total_input_tokens += input_tokens + input_tokens = response_json.get("prompt_eval_count") or len( + encoding.encode("".join(prompt for prompt in prompts)) + ) + total_input_tokens += input_tokens model_response.object = "list" model_response.data = output_data diff --git a/litellm/llms/ollama/completion/transformation.py b/litellm/llms/ollama/completion/transformation.py index cc5fddf9f7..c77fe7f028 100644 --- a/litellm/llms/ollama/completion/transformation.py +++ b/litellm/llms/ollama/completion/transformation.py @@ -1,20 +1,34 @@ +import json +import time import types -from typing import TYPE_CHECKING, Any, List, Optional, Union +import uuid +from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union from httpx._models import Headers, Response import litellm +from litellm.litellm_core_utils.prompt_templates.factory import ( + convert_to_ollama_image, + custom_prompt, + ollama_pt, +) +from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException from litellm.secret_managers.main import get_secret_str -from litellm.types.llms.openai import AllMessageValues +from litellm.types.llms.openai import ( + AllMessageValues, + ChatCompletionToolCallChunk, + ChatCompletionUsageBlock, +) from litellm.types.utils import ( + GenericStreamingChunk, ModelInfo, ModelResponse, ProviderField, StreamingChoices, ) -from ..common_utils import OllamaError +from ..common_utils import OllamaError, _convert_image if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj @@ -247,7 +261,47 @@ class OllamaConfig(BaseConfig): api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> ModelResponse: - raise NotImplementedError("transformation currently done in handler.py") + response_json = raw_response.json() + ## RESPONSE OBJECT + model_response.choices[0].finish_reason = "stop" + if request_data.get("format", "") == "json": + function_call = json.loads(response_json["response"]) + message = litellm.Message( + content=None, + tool_calls=[ + { + "id": f"call_{str(uuid.uuid4())}", + "function": { + "name": function_call["name"], + "arguments": json.dumps(function_call["arguments"]), + }, + "type": "function", + } + ], + ) + model_response.choices[0].message = message # type: ignore + model_response.choices[0].finish_reason = "tool_calls" + else: + model_response.choices[0].message.content 
= response_json["response"] # type: ignore + model_response.created = int(time.time()) + model_response.model = "ollama/" + model + _prompt = request_data.get("prompt", "") + prompt_tokens = response_json.get( + "prompt_eval_count", len(encoding.encode(_prompt, disallowed_special=())) # type: ignore + ) + completion_tokens = response_json.get( + "eval_count", len(response_json.get("message", dict()).get("content", "")) + ) + setattr( + model_response, + "usage", + litellm.Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + return model_response def transform_request( self, @@ -257,7 +311,46 @@ class OllamaConfig(BaseConfig): litellm_params: dict, headers: dict, ) -> dict: - raise NotImplementedError("transformation currently done in handler.py") + custom_prompt_dict = ( + litellm_params.get("custom_prompt_dict") or litellm.custom_prompt_dict + ) + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + ollama_prompt = custom_prompt( + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, + ) + else: + modified_prompt = ollama_pt(model=model, messages=messages) + if isinstance(modified_prompt, dict): + ollama_prompt, images = ( + modified_prompt["prompt"], + modified_prompt["images"], + ) + optional_params["images"] = images + else: + ollama_prompt = modified_prompt + stream = optional_params.pop("stream", False) + format = optional_params.pop("format", None) + images = optional_params.pop("images", None) + data = { + "model": model, + "prompt": ollama_prompt, + "options": optional_params, + "stream": stream, + } + + if format is not None: + data["format"] = format + if images is not None: + data["images"] = [ + _convert_image(convert_to_ollama_image(image)) for image in images + ] + + return data def validate_environment( self, @@ -267,4 +360,77 @@ class OllamaConfig(BaseConfig): optional_params: dict, api_key: Optional[str] = None, ) -> dict: - raise NotImplementedError("validation currently done in handler.py") + return headers + + def get_complete_url(self, api_base: str, model: str) -> str: + """ + OPTIONAL + + Get the complete url for the request + + Some providers need `model` in `api_base` + """ + if api_base.endswith("/api/generate"): + url = api_base + else: + url = f"{api_base}/api/generate" + + return url + + def get_model_response_iterator( + self, + streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse], + sync_stream: bool, + json_mode: Optional[bool] = False, + ): + return OllamaTextCompletionResponseIterator( + streaming_response=streaming_response, + sync_stream=sync_stream, + json_mode=json_mode, + ) + + +class OllamaTextCompletionResponseIterator(BaseModelResponseIterator): + def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk: + return self.chunk_parser(json.loads(str_line)) + + def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: + try: + if "error" in chunk: + raise Exception(f"Ollama Error - {chunk}") + + text = "" + is_finished = False + finish_reason = None + if chunk["done"] is True: + text = "" + is_finished = True + finish_reason = "stop" + prompt_eval_count: Optional[int] = chunk.get("prompt_eval_count", None) + eval_count: Optional[int] = chunk.get("eval_count", None) + + usage: 
Optional[ChatCompletionUsageBlock] = None + if prompt_eval_count is not None and eval_count is not None: + usage = ChatCompletionUsageBlock( + prompt_tokens=prompt_eval_count, + completion_tokens=eval_count, + total_tokens=prompt_eval_count + eval_count, + ) + return GenericStreamingChunk( + text=text, + is_finished=is_finished, + finish_reason=finish_reason, + usage=usage, + ) + elif chunk["response"]: + text = chunk["response"] + return GenericStreamingChunk( + text=text, + is_finished=is_finished, + finish_reason="stop", + usage=None, + ) + else: + raise Exception(f"Unable to parse ollama chunk - {chunk}") + except Exception as e: + raise e diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py index a0ccb81730..5fb35ba2bf 100644 --- a/litellm/llms/ollama_chat.py +++ b/litellm/llms/ollama_chat.py @@ -8,7 +8,6 @@ from typing import Any, List, Optional import aiohttp import httpx -import requests from pydantic import BaseModel import litellm @@ -297,13 +296,14 @@ def get_ollama_response( # noqa: PLR0915 url=url, api_key=api_key, data=data, logging_obj=logging_obj ) - _request = { - "url": f"{url}", - "json": data, - } + headers: Optional[dict] = None if api_key is not None: - _request["headers"] = {"Authorization": "Bearer {}".format(api_key)} - response = requests.post(**_request) # type: ignore + headers = {"Authorization": "Bearer {}".format(api_key)} + response = litellm.module_level_client.post( + url=url, + json=data, + headers=headers, + ) if response.status_code != 200: raise OllamaError(status_code=response.status_code, message=response.text) diff --git a/litellm/llms/oobabooga/chat/oobabooga.py b/litellm/llms/oobabooga/chat/oobabooga.py index 96b50ebbc4..30eaa049e1 100644 --- a/litellm/llms/oobabooga/chat/oobabooga.py +++ b/litellm/llms/oobabooga/chat/oobabooga.py @@ -4,12 +4,14 @@ import time from enum import Enum from typing import Any, Callable, Optional -import requests # type: ignore - +import litellm +from litellm.litellm_core_utils.prompt_templates.factory import ( + custom_prompt, + prompt_factory, +) from litellm.llms.custom_httpx.http_handler import HTTPHandler, _get_httpx_client from litellm.utils import EmbeddingResponse, ModelResponse, Usage -from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory from ..common_utils import OobaboogaError from .transformation import OobaboogaConfig @@ -129,9 +131,9 @@ def embedding( messages=[], optional_params=optional_params, ) - response = requests.post(embeddings_url, headers=headers, json=data) - if not response.ok: - raise OobaboogaError(message=response.text, status_code=response.status_code) + response = litellm.module_level_client.post( + embeddings_url, headers=headers, json=data + ) completion_response = response.json() # Check for errors in response diff --git a/litellm/llms/openai_like/chat/handler.py b/litellm/llms/openai_like/chat/handler.py index b3b1488409..2252dfc9cc 100644 --- a/litellm/llms/openai_like/chat/handler.py +++ b/litellm/llms/openai_like/chat/handler.py @@ -13,8 +13,7 @@ from enum import Enum from functools import partial from typing import Any, Callable, List, Literal, Optional, Tuple, Union -import httpx # type: ignore -import requests # type: ignore +import httpx import litellm from litellm import LlmProviders diff --git a/litellm/llms/predibase/chat/handler.py b/litellm/llms/predibase/chat/handler.py index 7352c2204c..a798ed6b3c 100644 --- a/litellm/llms/predibase/chat/handler.py +++ b/litellm/llms/predibase/chat/handler.py @@ -12,7 +12,6 @@ 
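# --- reviewer note, not part of the patch -----------------------------------
# Minimal sketch, assuming OllamaConfig() can still be instantiated with no
# arguments, of how prompt construction now runs through the transformation
# class instead of main.py / the old handler. The model name and message are
# illustrative; the exact prompt string depends on ollama_pt's template.
import litellm

config = litellm.OllamaConfig()
data = config.transform_request(
    model="llama3",
    messages=[{"role": "user", "content": "Say hello"}],
    optional_params={"temperature": 0.1, "stream": False},
    litellm_params={},
    headers={},
)
# roughly: {"model": "llama3", "prompt": "<templated prompt>",
#           "options": {"temperature": 0.1}, "stream": False}
print(data)
print(config.get_complete_url(api_base="http://localhost:11434", model="llama3"))
# -> http://localhost:11434/api/generate
# -----------------------------------------------------------------------------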
from functools import partial from typing import Callable, List, Literal, Optional, Union import httpx # type: ignore -import requests # type: ignore import litellm import litellm.litellm_core_utils @@ -63,7 +62,7 @@ async def make_call( return completion_stream -class PredibaseChatCompletion(BaseLLM): +class PredibaseChatCompletion: def __init__(self) -> None: super().__init__() @@ -90,7 +89,7 @@ class PredibaseChatCompletion(BaseLLM): def process_response( # noqa: PLR0915 self, model: str, - response: Union[requests.Response, httpx.Response], + response: httpx.Response, model_response: ModelResponse, stream: bool, logging_obj: LiteLLMLoggingBaseClass, @@ -347,7 +346,7 @@ class PredibaseChatCompletion(BaseLLM): ### SYNC STREAMING if stream is True: - response = requests.post( + response = litellm.module_level_client.post( completion_url, headers=headers, data=json.dumps(data), @@ -363,7 +362,7 @@ class PredibaseChatCompletion(BaseLLM): return _response ### SYNC COMPLETION else: - response = requests.post( + response = litellm.module_level_client.post( url=completion_url, headers=headers, data=json.dumps(data), diff --git a/litellm/llms/sagemaker/completion/handler.py b/litellm/llms/sagemaker/completion/handler.py index 41744f2543..a8b68f910b 100644 --- a/litellm/llms/sagemaker/completion/handler.py +++ b/litellm/llms/sagemaker/completion/handler.py @@ -10,12 +10,16 @@ from enum import Enum from functools import partial from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union -import httpx # type: ignore -import requests # type: ignore +import httpx import litellm from litellm._logging import verbose_logger from litellm.litellm_core_utils.asyncify import asyncify +from litellm.litellm_core_utils.prompt_templates.factory import ( + custom_prompt, + prompt_factory, +) +from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, @@ -31,8 +35,6 @@ from litellm.utils import ( get_secret, ) -from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM -from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory from ..common_utils import AWSEventStreamDecoder, SagemakerError from .transformation import SagemakerConfig diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index 09b74ad992..d438dea2fa 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -24,23 +24,22 @@ from typing import ( ) import httpx # type: ignore -import requests # type: ignore import litellm import litellm.litellm_core_utils import litellm.litellm_core_utils.litellm_logging from litellm import verbose_logger from litellm.litellm_core_utils.core_helpers import map_finish_reason +from litellm.litellm_core_utils.prompt_templates.factory import ( + convert_generic_image_chunk_to_openai_image_obj, + convert_to_anthropic_image_obj, +) from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, get_async_httpx_client, ) -from litellm.litellm_core_utils.prompt_templates.factory import ( - convert_generic_image_chunk_to_openai_image_obj, - convert_to_anthropic_image_obj, -) from litellm.types.llms.openai import ( AllMessageValues, ChatCompletionResponseMessage, diff --git 
a/litellm/llms/vertex_ai/vertex_ai_non_gemini.py b/litellm/llms/vertex_ai/vertex_ai_non_gemini.py index d365cfac14..8908ccc9f9 100644 --- a/litellm/llms/vertex_ai/vertex_ai_non_gemini.py +++ b/litellm/llms/vertex_ai/vertex_ai_non_gemini.py @@ -7,19 +7,18 @@ import uuid from enum import Enum from typing import Any, Callable, List, Literal, Optional, Union, cast -import httpx # type: ignore -import requests # type: ignore +import httpx from pydantic import BaseModel import litellm from litellm._logging import verbose_logger from litellm.litellm_core_utils.core_helpers import map_finish_reason -from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS from litellm.litellm_core_utils.prompt_templates.factory import ( convert_to_anthropic_image_obj, convert_to_gemini_tool_call_invoke, convert_to_gemini_tool_call_result, ) +from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS from litellm.types.files import ( get_file_mime_type_for_file_type, get_file_type_from_extension, diff --git a/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py b/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py index 93031656c2..01f0e5c27b 100644 --- a/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py +++ b/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py @@ -9,11 +9,19 @@ import uuid from enum import Enum from typing import Any, Callable, List, Optional, Tuple, Union -import httpx # type: ignore -import requests # type: ignore +import httpx import litellm from litellm.litellm_core_utils.core_helpers import map_finish_reason +from litellm.litellm_core_utils.prompt_templates.factory import ( + construct_tool_use_system_prompt, + contains_tag, + custom_prompt, + extract_between_tags, + parse_xml_params, + prompt_factory, + response_schema_prompt, +) from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.types.llms.openai import ( AllMessageValues, @@ -24,15 +32,6 @@ from litellm.types.utils import ResponseFormatChunk from litellm.utils import CustomStreamWrapper, ModelResponse, Usage from ....anthropic.chat.transformation import AnthropicConfig -from litellm.litellm_core_utils.prompt_templates.factory import ( - construct_tool_use_system_prompt, - contains_tag, - custom_prompt, - extract_between_tags, - parse_xml_params, - prompt_factory, - response_schema_prompt, -) class VertexAIError(Exception): diff --git a/litellm/llms/vllm/completion/handler.py b/litellm/llms/vllm/completion/handler.py index f8f1e54a1f..a64ed8974a 100644 --- a/litellm/llms/vllm/completion/handler.py +++ b/litellm/llms/vllm/completion/handler.py @@ -5,12 +5,13 @@ from enum import Enum from typing import Any, Callable import httpx -import requests # type: ignore +from litellm.litellm_core_utils.prompt_templates.factory import ( + custom_prompt, + prompt_factory, +) from litellm.utils import ModelResponse, Usage -from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory - llm = None diff --git a/litellm/main.py b/litellm/main.py index d8c4acf2ad..ad634d3f87 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2645,45 +2645,24 @@ def completion( # type: ignore # noqa: PLR0915 or get_secret("OLLAMA_API_BASE") or "http://localhost:11434" ) - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - if model in custom_prompt_dict: - # check if the model has a registered custom prompt - model_prompt_details = 
custom_prompt_dict[model] - ollama_prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages, - ) - else: - modified_prompt = ollama_pt(model=model, messages=messages) - if isinstance(modified_prompt, dict): - # for multimode models - ollama/llava prompt_factory returns a dict { - # "prompt": prompt, - # "images": images - # } - ollama_prompt, images = ( - modified_prompt["prompt"], - modified_prompt["images"], - ) - optional_params["images"] = images - else: - ollama_prompt = modified_prompt - ## LOGGING - generator = ollama.get_ollama_response( - api_base=api_base, + response = base_llm_http_handler.completion( model=model, - prompt=ollama_prompt, - optional_params=optional_params, - logging_obj=logging, + stream=stream, + messages=messages, acompletion=acompletion, + api_base=api_base, model_response=model_response, + optional_params=optional_params, + litellm_params=litellm_params, + custom_llm_provider="ollama", + timeout=timeout, + headers=headers, encoding=encoding, + api_key=api_key, + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements + client=client, ) - if acompletion is True or optional_params.get("stream", False) is True: - return generator - response = generator elif custom_llm_provider == "ollama_chat": api_base = ( litellm.api_base @@ -2833,8 +2812,6 @@ def completion( # type: ignore # noqa: PLR0915 return response response = model_response elif custom_llm_provider == "custom": - import requests - url = litellm.api_base or api_base or "" if url is None or url == "": raise ValueError( @@ -2843,7 +2820,7 @@ def completion( # type: ignore # noqa: PLR0915 """ assume input to custom LLM api bases follow this format: - resp = requests.post( + resp = litellm.module_level_client.post( api_base, json={ 'model': 'meta-llama/Llama-2-13b-hf', # model name @@ -2859,7 +2836,7 @@ def completion( # type: ignore # noqa: PLR0915 """ prompt = " ".join([message["content"] for message in messages]) # type: ignore - resp = requests.post( + resp = litellm.module_level_client.post( url, json={ "model": model, @@ -2871,7 +2848,6 @@ def completion( # type: ignore # noqa: PLR0915 "top_k": kwargs.get("top_k", 40), }, }, - verify=litellm.ssl_verify, ) response_json = resp.json() """ diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 094828de17..c77010a202 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -303,7 +303,7 @@ def run_server( # noqa: PLR0915 return if model and "ollama" in model and api_base is None: run_ollama_serve() - import requests + import httpx if test_async is True: import concurrent @@ -319,7 +319,7 @@ def run_server( # noqa: PLR0915 ], } - response = requests.post("http://0.0.0.0:4000/queue/request", json=data) + response = httpx.post("http://0.0.0.0:4000/queue/request", json=data) response = response.json() @@ -327,7 +327,7 @@ def run_server( # noqa: PLR0915 try: url = response["url"] polling_url = f"{api_base}{url}" - polling_response = requests.get(polling_url) + polling_response = httpx.get(polling_url) polling_response = polling_response.json() print("\n RESPONSE FROM POLLING JOB", polling_response) # noqa status = polling_response["status"] @@ -378,7 +378,7 @@ def run_server( # noqa: PLR0915 if health is not False: print("\nLiteLLM: Health Testing models in config") # noqa - response = 
requests.get(url=f"http://{host}:{port}/health") + response = httpx.get(url=f"http://{host}:{port}/health") print(json.dumps(response.json(), indent=4)) # noqa return if test is not False: diff --git a/litellm/router_strategy/least_busy.py b/litellm/router_strategy/least_busy.py index b1a85440f1..95deb8e6c8 100644 --- a/litellm/router_strategy/least_busy.py +++ b/litellm/router_strategy/least_busy.py @@ -11,9 +11,6 @@ import random import traceback from typing import Optional -import dotenv # type: ignore -import requests - from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger diff --git a/litellm/router_strategy/lowest_tpm_rpm.py b/litellm/router_strategy/lowest_tpm_rpm.py index 08d8086ef7..c99dc6a076 100644 --- a/litellm/router_strategy/lowest_tpm_rpm.py +++ b/litellm/router_strategy/lowest_tpm_rpm.py @@ -6,10 +6,6 @@ import traceback from datetime import datetime from typing import Dict, List, Optional, Union -import dotenv -import requests -from pydantic import BaseModel - from litellm import token_counter from litellm._logging import verbose_router_logger from litellm.caching.caching import DualCache diff --git a/litellm/utils.py b/litellm/utils.py index 5aa9ba2ade..093b4844dc 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -43,7 +43,6 @@ import aiohttp import dotenv import httpx import openai -import requests import tiktoken from httpx import Proxy from httpx._utils import get_environment_proxies @@ -4175,7 +4174,7 @@ def get_max_tokens(model: str) -> Optional[int]: config_url = f"https://huggingface.co/{model_name}/raw/main/config.json" try: # Make the HTTP request to get the raw JSON file - response = requests.get(config_url) + response = litellm.module_level_client.get(config_url) response.raise_for_status() # Raise an exception for bad responses (4xx or 5xx) # Parse the JSON response @@ -4186,7 +4185,7 @@ def get_max_tokens(model: str) -> Optional[int]: return max_position_embeddings else: return None - except requests.exceptions.RequestException: + except Exception: return None try: @@ -4361,7 +4360,7 @@ def get_model_info( # noqa: PLR0915 try: # Make the HTTP request to get the raw JSON file - response = requests.get(config_url) + response = litellm.module_level_client.get(config_url) response.raise_for_status() # Raise an exception for bad responses (4xx or 5xx) # Parse the JSON response @@ -4374,7 +4373,7 @@ def get_model_info( # noqa: PLR0915 return max_position_embeddings else: return None - except requests.exceptions.RequestException: + except Exception: return None try: diff --git a/tests/documentation_tests/test_requests_lib_usage.py b/tests/documentation_tests/test_requests_lib_usage.py new file mode 100644 index 0000000000..0c01b93aa8 --- /dev/null +++ b/tests/documentation_tests/test_requests_lib_usage.py @@ -0,0 +1,183 @@ +""" +Prevent usage of 'requests' library in the codebase. +""" + +import os +import ast +import sys +from typing import List, Tuple + + +def find_requests_usage(directory: str) -> List[Tuple[str, int, str]]: + """ + Recursively search for Python files in the given directory + and find usages of the 'requests' library. 
+ + Args: + directory (str): The root directory to search for Python files + + Returns: + List of tuples containing (file_path, line_number, usage_type) + """ + requests_usages = [] + + def is_likely_requests_usage(node): + """ + More precise check to avoid false positives + """ + try: + # Convert node to string representation + node_str = ast.unparse(node) + + # Specific checks to ensure it's the requests library + requests_identifiers = [ + # HTTP methods + "requests.get", + "requests.post", + "requests.put", + "requests.delete", + "requests.head", + "requests.patch", + "requests.options", + "requests.request", + "requests.session", + # Types and exceptions + "requests.Response", + "requests.Request", + "requests.Session", + "requests.ConnectionError", + "requests.HTTPError", + "requests.Timeout", + "requests.TooManyRedirects", + "requests.RequestException", + # Additional modules and attributes + "requests.api", + "requests.exceptions", + "requests.models", + "requests.auth", + "requests.cookies", + "requests.structures", + ] + + # Check for specific requests library identifiers + return any(identifier in node_str for identifier in requests_identifiers) + except: + return False + + def scan_file(file_path: str): + """ + Scan a single Python file for requests library usage + """ + try: + # Use utf-8-sig to handle files with BOM, ignore errors + with open(file_path, "r", encoding="utf-8-sig", errors="ignore") as file: + tree = ast.parse(file.read()) + + for node in ast.walk(tree): + # Check import statements + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name == "requests": + requests_usages.append( + (file_path, node.lineno, f"Import: {alias.name}") + ) + + # Check import from statements + elif isinstance(node, ast.ImportFrom): + if node.module == "requests": + requests_usages.append( + (file_path, node.lineno, f"Import from: {node.module}") + ) + + # Check method calls + elif isinstance(node, ast.Call): + # More precise check for requests usage + try: + if is_likely_requests_usage(node.func): + requests_usages.append( + ( + file_path, + node.lineno, + f"Method Call: {ast.unparse(node.func)}", + ) + ) + except: + pass + + # Check attribute access + elif isinstance(node, ast.Attribute): + try: + # More precise check + if is_likely_requests_usage(node): + requests_usages.append( + ( + file_path, + node.lineno, + f"Attribute Access: {ast.unparse(node)}", + ) + ) + except: + pass + + except SyntaxError as e: + print(f"Syntax error in {file_path}: {e}", file=sys.stderr) + except Exception as e: + print(f"Error processing {file_path}: {e}", file=sys.stderr) + + # Recursively walk through directory + for root, dirs, files in os.walk(directory): + # Remove virtual environment and cache directories from search + dirs[:] = [ + d + for d in dirs + if not any( + venv in d + for venv in [ + "venv", + "env", + "myenv", + ".venv", + "__pycache__", + ".pytest_cache", + ] + ) + ] + + for file in files: + if file.endswith(".py"): + full_path = os.path.join(root, file) + # Skip files in virtual environment or cache directories + if not any( + venv in full_path + for venv in [ + "venv", + "env", + "myenv", + ".venv", + "__pycache__", + ".pytest_cache", + ] + ): + scan_file(full_path) + + return requests_usages + + +def main(): + # Get directory from command line argument or use current directory + directory = "../../litellm" + + # Find requests library usages + results = find_requests_usage(directory) + + # Print results + if results: + print("Requests Library Usages Found:") + for 
file_path, line_num, usage_type in results: + print(f"{file_path}:{line_num} - {usage_type}") + else: + print("No requests library usages found.") + + +if __name__ == "__main__": + main() diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index 4c926dc806..b7d7473885 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -1940,10 +1940,11 @@ def test_ollama_image(): mock_response = MagicMock() mock_response.status_code = 200 mock_response.headers = {"Content-Type": "application/json"} + data_json = json.loads(kwargs["data"]) mock_response.json.return_value = { # return the image in the response so that it can be tested # against the original - "response": kwargs["json"]["images"] + "response": data_json["images"] } return mock_response @@ -1971,9 +1972,10 @@ def test_ollama_image(): [datauri_base64_data, datauri_base64_data], ] + client = HTTPHandler() for test in tests: try: - with patch("requests.post", side_effect=mock_post): + with patch.object(client, "post", side_effect=mock_post): response = completion( model="ollama/llava", messages=[ @@ -1988,6 +1990,7 @@ def test_ollama_image(): ], } ], + client=client, ) if not test[1]: # the conversion process may not always generate the same image, @@ -2387,8 +2390,8 @@ def test_completion_ollama_hosted(): response = completion( model="ollama/phi", messages=messages, - max_tokens=2, - api_base="https://test-ollama-endpoint.onrender.com", + max_tokens=20, + # api_base="https://test-ollama-endpoint.onrender.com", ) # Add any assertions here to check the response print(response) diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py index 02ac8cb91b..22d043a97b 100644 --- a/tests/local_testing/test_streaming.py +++ b/tests/local_testing/test_streaming.py @@ -606,14 +606,14 @@ def test_completion_azure_function_calling_stream(): @pytest.mark.skip("Flaky ollama test - needs to be fixed") def test_completion_ollama_hosted_stream(): try: - litellm.set_verbose = True + # litellm.set_verbose = True response = completion( model="ollama/phi", messages=messages, - max_tokens=10, + max_tokens=100, num_retries=3, timeout=20, - api_base="https://test-ollama-endpoint.onrender.com", + # api_base="https://test-ollama-endpoint.onrender.com", stream=True, ) # Add any assertions here to check the response
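# --- reviewer note, not part of the patch -----------------------------------
# Minimal sketch of the updated mocking pattern used in test_ollama_image above:
# with requests removed, tests patch `post` on an explicit HTTPHandler and pass
# that client into completion() rather than patching requests.post. The model
# name and canned "response" payload are illustrative; this is a sketch, not a
# verified end-to-end test.
from unittest.mock import MagicMock, patch

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()

def mock_post(url, **kwargs):
    # return the minimal shape OllamaConfig.transform_response reads
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = {"response": "mocked ollama output"}
    return mock_response

with patch.object(client, "post", side_effect=mock_post):
    resp = litellm.completion(
        model="ollama/llama3",
        messages=[{"role": "user", "content": "hi"}],
        client=client,  # routes the HTTP call through the patched handler
    )
print(resp.choices[0].message.content)
# -----------------------------------------------------------------------------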