LITELLM: Remove requests library usage (#7235)

* fix(generic_api_callback.py): remove requests lib usage

* fix(budget_manager.py): remove requests lib usage

* fix(main.py): cleanup requests lib usage

* fix(utils.py): remove requests lib usage

* fix(argilla.py): fix argilla test

* fix(athina.py): replace 'requests' lib usage with litellm module

* fix(greenscale.py): replace 'requests' lib usage with httpx

* fix: remove unused 'requests' lib import + replace usage in some places

* fix(prompt_layer.py): remove 'requests' lib usage from prompt layer

* fix(ollama_chat.py): remove 'requests' lib usage

* fix(baseten.py): replace 'requests' lib usage

* fix(codestral/): replace 'requests' lib usage

* fix(predibase/): replace 'requests' lib usage

* refactor: cleanup unused 'requests' lib imports

* fix(oobabooga.py): cleanup 'requests' lib usage

* fix(invoke_handler.py): remove unused 'requests' lib usage

* refactor: cleanup unused 'requests' lib import

* fix: fix linting errors

* refactor(ollama/): move ollama to using base llm http handler

removes 'requests' lib dep for ollama integration

* fix(ollama_chat.py): fix linting errors

* fix(ollama/completion/transformation.py): convert non-jpeg/png image to jpeg/png before passing to ollama
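
Taken together, these commits apply one mechanical pattern: direct calls through the 'requests' library are replaced with litellm's shared module-level httpx clients. A minimal before/after sketch of that pattern (the endpoint URL and payload below are placeholders, not code from this change):

    import litellm

    url = "https://example.com/log"    # placeholder endpoint
    payload = {"event": "completion"}  # placeholder payload

    # before: hard dependency on the requests library
    # import requests
    # response = requests.post(url, json=payload)

    # after: reuse litellm's shared clients (sync and async variants)
    response = litellm.module_level_client.post(url, json=payload)
    # response = await litellm.module_level_aclient.post(url=url, json=payload)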
Authored by Krish Dholakia on 2024-12-17 12:50:04 -08:00; committed by GitHub
parent f628290ce7
commit 03e711e3e4
46 changed files with 523 additions and 612 deletions

View file

@@ -3,7 +3,6 @@
 #### What this does ####
 # On success, logs events to Promptlayer
 import dotenv, os
-import requests
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching.caching import DualCache
@@ -17,7 +16,6 @@ import traceback
 # On success + failure, log events to Supabase
 import dotenv, os
-import requests
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid
@@ -116,7 +114,9 @@ class GenericAPILogger:
 print_verbose(f"\nGeneric Logger - Logging payload = {data}")
 # make request to endpoint with payload
-response = requests.post(self.endpoint, json=data, headers=self.headers)
+response = litellm.module_level_client.post(
+self.endpoint, json=data, headers=self.headers
+)
 response_status = response.status_code
 response_text = response.text

View file

@@ -13,8 +13,6 @@ import threading
 import time
 from typing import Literal, Optional, Union
-import requests # type: ignore
 import litellm
 from litellm.utils import ModelResponse
@@ -58,7 +56,9 @@ class BudgetManager:
 # Load the user_dict from hosted db
 url = self.api_base + "/get_budget"
 data = {"project_name": self.project_name}
-response = requests.post(url, headers=self.headers, json=data)
+response = litellm.module_level_client.post(
+url, headers=self.headers, json=data
+)
 response = response.json()
 if response["status"] == "error":
 self.user_dict = (
@@ -215,6 +215,8 @@ class BudgetManager:
 elif self.client_type == "hosted":
 url = self.api_base + "/set_budget"
 data = {"project_name": self.project_name, "user_dict": self.user_dict}
-response = requests.post(url, headers=self.headers, json=data)
+response = litellm.module_level_client.post(
+url, headers=self.headers, json=data
+)
 response = response.json()
 return response

View file

@@ -15,19 +15,20 @@ from typing import Any, Dict, List, Optional, TypedDict, Union
 import dotenv # type: ignore
 import httpx
-import requests # type: ignore
 from pydantic import BaseModel # type: ignore
 import litellm
 from litellm._logging import verbose_logger
 from litellm.integrations.custom_batch_logger import CustomBatchLogger
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+get_content_from_model_response,
+)
 from litellm.llms.custom_httpx.http_handler import (
 AsyncHTTPHandler,
 get_async_httpx_client,
 httpxSpecialProvider,
 )
-from litellm.litellm_core_utils.prompt_templates.common_utils import get_content_from_model_response
 from litellm.types.integrations.argilla import (
 SUPPORTED_PAYLOAD_FIELDS,
 ArgillaCredentialsObject,
@@ -223,7 +224,7 @@ class ArgillaLogger(CustomBatchLogger):
 headers = {"X-Argilla-Api-Key": argilla_api_key}
 try:
-response = requests.post(
+response = litellm.module_level_client.post(
 url=url,
 json=self.log_queue,
 headers=headers,

View file

@@ -1,5 +1,7 @@
 import datetime
+import litellm
 class AthinaLogger:
 def __init__(self):
@@ -27,8 +29,6 @@ class AthinaLogger:
 import json
 import traceback
-import requests # type: ignore
 try:
 is_stream = kwargs.get("stream", False)
 if is_stream:
@@ -81,7 +81,7 @@ class AthinaLogger:
 if key in metadata:
 data[key] = metadata[key]
-response = requests.post(
+response = litellm.module_level_client.post(
 self.athina_logging_url,
 headers=self.headers,
 data=json.dumps(data, default=str),

View file

@@ -8,7 +8,6 @@ import uuid
 from typing import Any
 import dotenv
-import requests # type: ignore
 import litellm

View file

@@ -2,7 +2,7 @@ import json
 import traceback
 from datetime import datetime, timezone
-import requests # type: ignore
+import litellm
 class GreenscaleLogger:
@@ -54,7 +54,7 @@ class GreenscaleLogger:
 if self.greenscale_logging_url is None:
 raise Exception("Greenscale Logger Error - No logging URL found")
-response = requests.post(
+response = litellm.module_level_client.post(
 self.greenscale_logging_url,
 headers=self.headers,
 data=json.dumps(data, default=str),

View file

@@ -4,7 +4,6 @@ import os
 import traceback
 import dotenv
-import requests # type: ignore
 import litellm
 from litellm._logging import verbose_logger
@@ -179,7 +178,7 @@ class HeliconeLogger:
 },
 }, # {"seconds": .., "milliseconds": ..}
 }
-response = requests.post(url, headers=headers, json=data)
+response = litellm.module_level_client.post(url, headers=headers, json=data)
 if response.status_code == 200:
 print_verbose("Helicone Logging - Success!")
 else:

View file

@@ -12,7 +12,6 @@ from typing import Any, Dict, List, Optional, TypedDict, Union
 import dotenv # type: ignore
 import httpx
-import requests # type: ignore
 from pydantic import BaseModel # type: ignore
 import litellm
@@ -481,7 +480,7 @@ class LangsmithLogger(CustomBatchLogger):
 langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
 url = f"{langsmith_api_base}/runs/{run_id}"
-response = requests.get(
+response = litellm.module_level_client.get(
 url=url,
 headers={"x-api-key": langsmith_api_key},
 )

View file

@@ -9,9 +9,6 @@ import uuid
 from datetime import date, datetime, timedelta
 from typing import Optional, TypedDict, Union
-import dotenv
-import requests # type: ignore
 import litellm
 from litellm._logging import print_verbose, verbose_logger
 from litellm.integrations.custom_logger import CustomLogger

View file

@@ -11,9 +11,6 @@ import traceback
 import uuid
 from typing import List, Optional, Union
-import dotenv
-import requests # type: ignore
 import litellm
 from litellm._logging import print_verbose, verbose_logger
 from litellm.types.integrations.prometheus import LATENCY_BUCKETS

View file

@@ -3,10 +3,10 @@
 import os
 import traceback
-import dotenv
-import requests # type: ignore
 from pydantic import BaseModel
+import litellm
 class PromptLayerLogger:
 # Class variables or attributes
@@ -47,7 +47,7 @@ class PromptLayerLogger:
 if isinstance(response_obj, BaseModel):
 response_obj = response_obj.model_dump()
-request_response = requests.post(
+request_response = litellm.module_level_client.post(
 "https://api.promptlayer.com/rest/track-request",
 json={
 "function_name": "openai.ChatCompletion.create",
@@ -74,7 +74,7 @@ class PromptLayerLogger:
 if "request_id" in response_json:
 if metadata:
-response = requests.post(
+response = litellm.module_level_client.post(
 "https://api.promptlayer.com/rest/track-metadata",
 json={
 "request_id": response_json["request_id"],

View file

@@ -8,7 +8,6 @@ import sys
 import traceback
 import dotenv
-import requests # type: ignore
 import litellm

View file

@@ -177,8 +177,6 @@ import os
 import traceback
 from datetime import datetime
-import requests
 class WeightsBiasesLogger:
 # Class variables or attributes

View file

@@ -570,40 +570,6 @@ class CustomStreamWrapper:
 )
 return ""
def handle_ollama_stream(self, chunk):
try:
if isinstance(chunk, dict):
json_chunk = chunk
else:
json_chunk = json.loads(chunk)
if "error" in json_chunk:
raise Exception(f"Ollama Error - {json_chunk}")
text = ""
is_finished = False
finish_reason = None
if json_chunk["done"] is True:
text = ""
is_finished = True
finish_reason = "stop"
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
elif json_chunk["response"]:
print_verbose(f"delta content: {json_chunk}")
text = json_chunk["response"]
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
else:
raise Exception(f"Ollama Error - {json_chunk}")
except Exception as e:
raise e
 def handle_ollama_chat_stream(self, chunk):
 # for ollama_chat/ provider
 try:
@@ -1111,12 +1077,6 @@ class CustomStreamWrapper:
 new_chunk = self.completion_stream[:chunk_size]
 completion_obj["content"] = new_chunk
 self.completion_stream = self.completion_stream[chunk_size:]
-elif self.custom_llm_provider == "ollama":
-response_obj = self.handle_ollama_stream(chunk)
-completion_obj["content"] = response_obj["text"]
-print_verbose(f"completion obj content: {completion_obj['content']}")
-if response_obj["is_finished"]:
-self.received_finish_reason = response_obj["finish_reason"]
 elif self.custom_llm_provider == "ollama_chat":
 response_obj = self.handle_ollama_chat_stream(chunk)
 completion_obj["content"] = response_obj["text"]

View file

@@ -13,7 +13,6 @@ from functools import partial
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 import httpx # type: ignore
-import requests # type: ignore
 from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice
 import litellm

View file

@@ -16,7 +16,6 @@ from typing import (
 )
 import httpx
-import requests
 import litellm
 from litellm.constants import RESPONSE_FORMAT_TOOL_NAME

View file

@@ -4,11 +4,14 @@ import uuid
 from typing import Any, Callable, Optional, Union
 import httpx
-import requests
 from openai import AsyncAzureOpenAI, AzureOpenAI
 import litellm
 from litellm import OpenAIConfig
+from litellm.litellm_core_utils.prompt_templates.factory import (
+custom_prompt,
+prompt_factory,
+)
 from litellm.utils import (
 Choices,
 CustomStreamWrapper,
@@ -22,7 +25,6 @@ from litellm.utils import (
 from ...base import BaseLLM
 from ...openai.completion.handler import OpenAITextCompletion
 from ...openai.completion.transformation import OpenAITextCompletionConfig
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
 from ..common_utils import AzureOpenAIError
 openai_text_completion_config = OpenAITextCompletionConfig()

View file

@@ -4,9 +4,8 @@ import time
 from enum import Enum
 from typing import Callable
-import requests # type: ignore
-from litellm.utils import ModelResponse, Usage
+import litellm
+from litellm.types.utils import ModelResponse, Usage
 class BasetenError(Exception):
@@ -71,7 +70,7 @@ def completion(
 additional_args={"complete_input_dict": data},
 )
 ## COMPLETION CALL
-response = requests.post(
+response = litellm.module_level_client.post(
 completion_url_fragment_1 + model + completion_url_fragment_2,
 headers=headers,
 data=json.dumps(data),

View file

@@ -10,8 +10,6 @@ from litellm._logging import verbose_logger
 from litellm.caching.caching import DualCache, InMemoryCache
 from litellm.secret_managers.main import get_secret, get_secret_str
-from litellm.llms.base import BaseLLM
 if TYPE_CHECKING:
 from botocore.credentials import Credentials
 else:
@@ -37,7 +35,7 @@ class AwsAuthError(Exception):
 ) # Call the base class constructor with the parameters it needs
-class BaseAWSLLM(BaseLLM):
+class BaseAWSLLM:
 def __init__(self) -> None:
 self.iam_cache = DualCache()
 super().__init__()

View file

@@ -25,7 +25,6 @@ from typing import (
 )
 import httpx # type: ignore
-import requests # type: ignore
 import litellm
 from litellm import verbose_logger
@@ -316,7 +315,7 @@ class BedrockLLM(BaseAWSLLM):
 def process_response( # noqa: PLR0915
 self,
 model: str,
-response: Union[requests.Response, httpx.Response],
+response: httpx.Response,
 model_response: ModelResponse,
 stream: bool,
 logging_obj: Logging,
@@ -1041,9 +1040,6 @@ class BedrockLLM(BaseAWSLLM):
 )
 return streaming_response
-def embedding(self, *args, **kwargs):
-return super().embedding(*args, **kwargs)
 def get_response_stream_shape():
 global _response_stream_shape_cache

View file

@@ -12,7 +12,6 @@ from functools import partial
 from typing import Callable, List, Literal, Optional, Union
 import httpx # type: ignore
-import requests # type: ignore
 import litellm
 from litellm import verbose_logger
@@ -22,7 +21,6 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
 custom_prompt,
 prompt_factory,
 )
-from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
 AsyncHTTPHandler,
 get_async_httpx_client,
@@ -95,7 +93,7 @@ async def make_call(
 return completion_stream
-class CodestralTextCompletion(BaseLLM):
+class CodestralTextCompletion:
 def __init__(self) -> None:
 super().__init__()
@@ -139,7 +137,7 @@ class CodestralTextCompletion(BaseLLM):
 def process_text_completion_response(
 self,
 model: str,
-response: Union[requests.Response, httpx.Response],
+response: httpx.Response,
 model_response: TextCompletionResponse,
 stream: bool,
 logging_obj: LiteLLMLogging,
@@ -317,7 +315,7 @@ class CodestralTextCompletion(BaseLLM):
 ### SYNC STREAMING
 if stream is True:
-response = requests.post(
+response = litellm.module_level_client.post(
 completion_url,
 headers=headers,
 data=json.dumps(data),
@@ -333,7 +331,7 @@ class CodestralTextCompletion(BaseLLM):
 ### SYNC COMPLETION
 else:
-response = requests.post(
+response = litellm.module_level_client.post(
 url=completion_url,
 headers=headers,
 data=json.dumps(data),

View file

@@ -6,8 +6,7 @@ import types
 from enum import Enum
 from typing import Any, Callable, Optional, Union
-import httpx # type: ignore
-import requests # type: ignore
+import httpx
 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

View file

@@ -491,7 +491,7 @@ class HTTPHandler:
 self,
 url: str,
 data: Optional[Union[dict, str]] = None,
-json: Optional[Union[dict, str]] = None,
+json: Optional[Union[dict, str, List]] = None,
 params: Optional[dict] = None,
 headers: Optional[dict] = None,
 stream: bool = False,
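
The widened `json` annotation above matters for callers that post list payloads (for example, a batch logger sending its queued records). A small hedged sketch, assuming HTTPHandler's default constructor; URL, payload, and header values are placeholders:

    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    client = HTTPHandler()
    resp = client.post(
        "https://example.com/api/records",    # placeholder URL
        json=[{"id": 1}, {"id": 2}],          # a list payload now type-checks
        headers={"X-Api-Key": "placeholder"},
    )
    print(resp.status_code)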

View file

@@ -29,8 +29,7 @@ from typing import (
 Union,
 )
-import httpx # type: ignore
-import requests # type: ignore
+import httpx
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
@@ -46,6 +45,7 @@ from litellm.utils import (
 from .base import BaseLLM
 class CustomLLMError(Exception): # use this for all your exceptions
 def __init__(
 self,

View file

@@ -6,7 +6,6 @@ from enum import Enum
 from typing import Callable, Optional
 import httpx # type: ignore
-import requests # type: ignore
 import litellm
 from litellm.utils import Choices, Message, ModelResponse, Usage
@@ -240,7 +239,7 @@ def completion(
 additional_args={"complete_input_dict": data},
 )
 ## COMPLETION CALL
-response = requests.post(
+response = litellm.module_level_client.post(
 completion_url,
 headers=headers,
 data=json.dumps(data),

View file

@@ -20,7 +20,6 @@ from typing import (
 )
 import httpx
-import requests
 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

View file

@@ -10,3 +10,36 @@ class OllamaError(BaseLLMException):
 self, status_code: int, message: str, headers: Union[dict, httpx.Headers]
 ):
 super().__init__(status_code=status_code, message=message, headers=headers)
def _convert_image(image):
"""
Convert image to base64 encoded image if not already in base64 format
If image is already in base64 format AND is a jpeg/png, return it
If image is not JPEG/PNG, convert it to JPEG base64 format
"""
import base64
import io
try:
from PIL import Image
except Exception:
raise Exception(
"ollama image conversion failed please run `pip install Pillow`"
)
orig = image
if image.startswith("data:"):
image = image.split(",")[-1]
try:
image_data = Image.open(io.BytesIO(base64.b64decode(image)))
if image_data.format in ["JPEG", "PNG"]:
return image
except Exception:
return orig
jpeg_image = io.BytesIO()
image_data.convert("RGB").save(jpeg_image, "JPEG")
jpeg_image.seek(0)
return base64.b64encode(jpeg_image.getvalue()).decode("utf-8")
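
A hedged usage sketch of the `_convert_image` helper added above (the import path is assumed from the relative import in ollama/completion/transformation.py, the payload is a placeholder, and Pillow must be installed):

    from litellm.llms.ollama.common_utils import _convert_image

    data_uri = "data:image/webp;base64,AAAA"  # placeholder base64 payload
    ollama_image = _convert_image(data_uri)   # plain base64, re-encoded to JPEG if not JPEG/PNG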

View file

@@ -1,3 +1,9 @@
+"""
+Ollama /chat/completion calls handled in llm_http_handler.py
+[TODO]: migrate embeddings to a base handler as well.
+"""
 import asyncio
 import json
 import time
@@ -8,10 +14,6 @@ from copy import deepcopy
 from itertools import chain
 from typing import Any, Dict, List, Optional
-import aiohttp
-import httpx # type: ignore
-import requests # type: ignore
 import litellm
 from litellm import verbose_logger
 from litellm.litellm_core_utils.prompt_templates.factory import (
@@ -31,370 +33,8 @@ from litellm.types.utils import (
 from ..common_utils import OllamaError
 from .transformation import OllamaConfig
# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI # ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
# and convert to jpeg if necessary. # and convert to jpeg if necessary.
def _convert_image(image):
import base64
import io
try:
from PIL import Image
except Exception:
raise Exception(
"ollama image conversion failed please run `pip install Pillow`"
)
orig = image
if image.startswith("data:"):
image = image.split(",")[-1]
try:
image_data = Image.open(io.BytesIO(base64.b64decode(image)))
if image_data.format in ["JPEG", "PNG"]:
return image
except Exception:
return orig
jpeg_image = io.BytesIO()
image_data.convert("RGB").save(jpeg_image, "JPEG")
jpeg_image.seek(0)
return base64.b64encode(jpeg_image.getvalue()).decode("utf-8")
# ollama implementation
def get_ollama_response(
model_response: ModelResponse,
model: str,
prompt: str,
optional_params: dict,
logging_obj: Any,
encoding: Any,
acompletion: bool = False,
api_base="http://localhost:11434",
):
if api_base.endswith("/api/generate"):
url = api_base
else:
url = f"{api_base}/api/generate"
## Load Config
config = litellm.OllamaConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
stream = optional_params.pop("stream", False)
format = optional_params.pop("format", None)
images = optional_params.pop("images", None)
data = {
"model": model,
"prompt": prompt,
"options": optional_params,
"stream": stream,
}
if format is not None:
data["format"] = format
if images is not None:
data["images"] = [_convert_image(image) for image in images]
## LOGGING
logging_obj.pre_call(
input=None,
api_key=None,
additional_args={
"api_base": url,
"complete_input_dict": data,
"headers": {},
"acompletion": acompletion,
},
)
if acompletion is True:
if stream is True:
response = ollama_async_streaming(
url=url,
data=data,
model_response=model_response,
encoding=encoding,
logging_obj=logging_obj,
)
else:
response = ollama_acompletion(
url=url,
data=data,
model_response=model_response,
encoding=encoding,
logging_obj=logging_obj,
)
return response
elif stream is True:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
response = requests.post(
url=f"{url}", json={**data, "stream": stream}, timeout=litellm.request_timeout
)
if response.status_code != 200:
raise OllamaError(
status_code=response.status_code,
message=response.text,
headers=dict(response.headers),
)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
original_response=response.text,
additional_args={
"headers": None,
"api_base": api_base,
},
)
response_json = response.json()
## RESPONSE OBJECT
model_response.choices[0].finish_reason = "stop"
if data.get("format", "") == "json":
function_call = json.loads(response_json["response"])
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function",
}
],
)
model_response.choices[0].message = message # type: ignore
model_response.choices[0].finish_reason = "tool_calls"
else:
model_response.choices[0].message.content = response_json["response"] # type: ignore
model_response.created = int(time.time())
model_response.model = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
completion_tokens = response_json.get(
"eval_count", len(response_json.get("message", dict()).get("content", ""))
)
setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return model_response
def ollama_completion_stream(url, data, logging_obj):
with httpx.stream(
url=url, json=data, method="POST", timeout=litellm.request_timeout
) as response:
try:
if response.status_code != 200:
raise OllamaError(
status_code=response.status_code,
message=str(response.read()),
headers=response.headers,
)
streamwrapper = litellm.CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=data["model"],
custom_llm_provider="ollama",
logging_obj=logging_obj,
)
# If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json":
first_chunk = next(streamwrapper)
content_chunks = []
for chunk in chain([first_chunk], streamwrapper):
content_chunk = chunk.choices[0]
if (
isinstance(content_chunk, StreamingChoices)
and hasattr(content_chunk, "delta")
and hasattr(content_chunk.delta, "content")
and content_chunk.delta.content is not None
):
content_chunks.append(content_chunk.delta.content)
response_content = "".join(content_chunks)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function",
}
],
)
model_response = first_chunk
model_response.choices[0].delta = delta # type: ignore
model_response.choices[0].finish_reason = "tool_calls"
yield model_response
else:
for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
raise e
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
try:
_async_http_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.OLLAMA
)
client = _async_http_client.client
async with client.stream(
url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
) as response:
if response.status_code != 200:
raise OllamaError(
status_code=response.status_code,
message=str(await response.aread()),
headers=dict(response.headers),
)
streamwrapper = litellm.CustomStreamWrapper(
completion_stream=response.aiter_lines(),
model=data["model"],
custom_llm_provider="ollama",
logging_obj=logging_obj,
)
# If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json":
first_chunk = await anext(streamwrapper) # noqa F821
chunk_choice = first_chunk.choices[0]
if (
isinstance(chunk_choice, StreamingChoices)
and hasattr(chunk_choice, "delta")
and hasattr(chunk_choice.delta, "content")
):
first_chunk_content = chunk_choice.delta.content or ""
else:
first_chunk_content = ""
content_chunks = []
async for chunk in streamwrapper:
chunk_choice = chunk.choices[0]
if (
isinstance(chunk_choice, StreamingChoices)
and hasattr(chunk_choice, "delta")
and hasattr(chunk_choice.delta, "content")
):
content_chunks.append(chunk_choice.delta.content)
response_content = first_chunk_content + "".join(content_chunks)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function",
}
],
)
model_response = first_chunk
model_response.choices[0].delta = delta # type: ignore
model_response.choices[0].finish_reason = "tool_calls"
yield model_response
else:
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
raise e # don't use verbose_logger.exception, if exception is raised
async def ollama_acompletion(
url, data, model_response: litellm.ModelResponse, encoding, logging_obj
):
data["stream"] = False
try:
timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes
async with aiohttp.ClientSession(timeout=timeout) as session:
resp = await session.post(url, json=data)
if resp.status != 200:
text = await resp.text()
raise OllamaError(
status_code=resp.status,
message=text,
headers=dict(resp.headers),
)
## LOGGING
logging_obj.post_call(
input=data["prompt"],
api_key="",
original_response=resp.text,
additional_args={
"headers": None,
"api_base": url,
},
)
response_json = await resp.json()
## RESPONSE OBJECT
model_response.choices[0].finish_reason = "stop"
if data.get("format", "") == "json":
function_call = json.loads(response_json["response"])
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"name": function_call.get(
"name", function_call.get("function", None)
),
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function",
}
],
)
model_response.choices[0].message = message # type: ignore
model_response.choices[0].finish_reason = "tool_calls"
else:
model_response.choices[0].message.content = response_json["response"] # type: ignore
model_response.created = int(time.time())
model_response.model = "ollama/" + data["model"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
completion_tokens = response_json.get(
"eval_count",
len(response_json.get("message", dict()).get("content", "")),
)
setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return model_response
except Exception as e:
raise e # don't use verbose_logger.exception, if exception is raised
 async def ollama_aembeddings(
 api_base: str,
@@ -432,39 +72,18 @@ async def ollama_aembeddings(
 total_input_tokens = 0
 output_data = []
-timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes
-async with aiohttp.ClientSession(timeout=timeout) as session:
-## LOGGING
-logging_obj.pre_call(
-input=None,
-api_key=None,
-additional_args={
-"api_base": url,
-"complete_input_dict": data,
-"headers": {},
-},
-)
-response = await session.post(url, json=data)
-if response.status != 200:
-text = await response.text()
-raise OllamaError(
-status_code=response.status,
-message=text,
-headers=dict(response.headers),
-)
-response_json = await response.json()
-embeddings: List[List[float]] = response_json["embeddings"]
-for idx, emb in enumerate(embeddings):
-output_data.append({"object": "embedding", "index": idx, "embedding": emb})
-input_tokens = response_json.get("prompt_eval_count") or len(
-encoding.encode("".join(prompt for prompt in prompts))
-)
-total_input_tokens += input_tokens
+response = await litellm.module_level_aclient.post(url=url, json=data)
+response_json = await response.json()
+embeddings: List[List[float]] = response_json["embeddings"]
+for idx, emb in enumerate(embeddings):
+output_data.append({"object": "embedding", "index": idx, "embedding": emb})
+input_tokens = response_json.get("prompt_eval_count") or len(
+encoding.encode("".join(prompt for prompt in prompts))
+)
+total_input_tokens += input_tokens
 model_response.object = "list"
 model_response.data = output_data

View file

@@ -1,20 +1,34 @@
+import json
+import time
 import types
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+import uuid
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
 from httpx._models import Headers, Response
 import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+convert_to_ollama_image,
+custom_prompt,
+ollama_pt,
+)
+from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.secret_managers.main import get_secret_str
-from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.openai import (
+AllMessageValues,
+ChatCompletionToolCallChunk,
+ChatCompletionUsageBlock,
+)
 from litellm.types.utils import (
+GenericStreamingChunk,
 ModelInfo,
 ModelResponse,
 ProviderField,
 StreamingChoices,
 )
-from ..common_utils import OllamaError
+from ..common_utils import OllamaError, _convert_image
 if TYPE_CHECKING:
 from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -247,7 +261,47 @@ class OllamaConfig(BaseConfig):
 api_key: Optional[str] = None,
 json_mode: Optional[bool] = None,
 ) -> ModelResponse:
-raise NotImplementedError("transformation currently done in handler.py")
+response_json = raw_response.json()
## RESPONSE OBJECT
model_response.choices[0].finish_reason = "stop"
if request_data.get("format", "") == "json":
function_call = json.loads(response_json["response"])
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function",
}
],
)
model_response.choices[0].message = message # type: ignore
model_response.choices[0].finish_reason = "tool_calls"
else:
model_response.choices[0].message.content = response_json["response"] # type: ignore
model_response.created = int(time.time())
model_response.model = "ollama/" + model
_prompt = request_data.get("prompt", "")
prompt_tokens = response_json.get(
"prompt_eval_count", len(encoding.encode(_prompt, disallowed_special=())) # type: ignore
)
completion_tokens = response_json.get(
"eval_count", len(response_json.get("message", dict()).get("content", ""))
)
setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return model_response
 def transform_request(
 self,
@@ -257,7 +311,46 @@ class OllamaConfig(BaseConfig):
 litellm_params: dict,
 headers: dict,
 ) -> dict:
-raise NotImplementedError("transformation currently done in handler.py")
+custom_prompt_dict = (
litellm_params.get("custom_prompt_dict") or litellm.custom_prompt_dict
)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
ollama_prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
modified_prompt = ollama_pt(model=model, messages=messages)
if isinstance(modified_prompt, dict):
ollama_prompt, images = (
modified_prompt["prompt"],
modified_prompt["images"],
)
optional_params["images"] = images
else:
ollama_prompt = modified_prompt
stream = optional_params.pop("stream", False)
format = optional_params.pop("format", None)
images = optional_params.pop("images", None)
data = {
"model": model,
"prompt": ollama_prompt,
"options": optional_params,
"stream": stream,
}
if format is not None:
data["format"] = format
if images is not None:
data["images"] = [
_convert_image(convert_to_ollama_image(image)) for image in images
]
return data
 def validate_environment(
 self,
@@ -267,4 +360,77 @@ class OllamaConfig(BaseConfig):
 optional_params: dict,
 api_key: Optional[str] = None,
 ) -> dict:
-raise NotImplementedError("validation currently done in handler.py")
+return headers
def get_complete_url(self, api_base: str, model: str) -> str:
"""
OPTIONAL
Get the complete url for the request
Some providers need `model` in `api_base`
"""
if api_base.endswith("/api/generate"):
url = api_base
else:
url = f"{api_base}/api/generate"
return url
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return OllamaTextCompletionResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
class OllamaTextCompletionResponseIterator(BaseModelResponseIterator):
def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
return self.chunk_parser(json.loads(str_line))
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
if "error" in chunk:
raise Exception(f"Ollama Error - {chunk}")
text = ""
is_finished = False
finish_reason = None
if chunk["done"] is True:
text = ""
is_finished = True
finish_reason = "stop"
prompt_eval_count: Optional[int] = chunk.get("prompt_eval_count", None)
eval_count: Optional[int] = chunk.get("eval_count", None)
usage: Optional[ChatCompletionUsageBlock] = None
if prompt_eval_count is not None and eval_count is not None:
usage = ChatCompletionUsageBlock(
prompt_tokens=prompt_eval_count,
completion_tokens=eval_count,
total_tokens=prompt_eval_count + eval_count,
)
return GenericStreamingChunk(
text=text,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
)
elif chunk["response"]:
text = chunk["response"]
return GenericStreamingChunk(
text=text,
is_finished=is_finished,
finish_reason="stop",
usage=None,
)
else:
raise Exception(f"Unable to parse ollama chunk - {chunk}")
except Exception as e:
raise e

View file

@@ -8,7 +8,6 @@ from typing import Any, List, Optional
 import aiohttp
 import httpx
-import requests
 from pydantic import BaseModel
 import litellm
@@ -297,13 +296,14 @@ def get_ollama_response( # noqa: PLR0915
 url=url, api_key=api_key, data=data, logging_obj=logging_obj
 )
-_request = {
-"url": f"{url}",
-"json": data,
-}
+headers: Optional[dict] = None
 if api_key is not None:
-_request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
+headers = {"Authorization": "Bearer {}".format(api_key)}
-response = requests.post(**_request) # type: ignore
+response = litellm.module_level_client.post(
+url=url,
+json=data,
+headers=headers,
+)
 if response.status_code != 200:
 raise OllamaError(status_code=response.status_code, message=response.text)

View file

@@ -4,12 +4,14 @@ import time
 from enum import Enum
 from typing import Any, Callable, Optional
-import requests # type: ignore
+import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+custom_prompt,
+prompt_factory,
+)
 from litellm.llms.custom_httpx.http_handler import HTTPHandler, _get_httpx_client
 from litellm.utils import EmbeddingResponse, ModelResponse, Usage
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
 from ..common_utils import OobaboogaError
 from .transformation import OobaboogaConfig
@@ -129,9 +131,9 @@ def embedding(
 messages=[],
 optional_params=optional_params,
 )
-response = requests.post(embeddings_url, headers=headers, json=data)
-if not response.ok:
-raise OobaboogaError(message=response.text, status_code=response.status_code)
+response = litellm.module_level_client.post(
+embeddings_url, headers=headers, json=data
+)
 completion_response = response.json()
 # Check for errors in response

View file

@@ -13,8 +13,7 @@ from enum import Enum
 from functools import partial
 from typing import Any, Callable, List, Literal, Optional, Tuple, Union
-import httpx # type: ignore
-import requests # type: ignore
+import httpx
 import litellm
 from litellm import LlmProviders

View file

@@ -12,7 +12,6 @@ from functools import partial
 from typing import Callable, List, Literal, Optional, Union
 import httpx # type: ignore
-import requests # type: ignore
 import litellm
 import litellm.litellm_core_utils
@@ -63,7 +62,7 @@ async def make_call(
 return completion_stream
-class PredibaseChatCompletion(BaseLLM):
+class PredibaseChatCompletion:
 def __init__(self) -> None:
 super().__init__()
@@ -90,7 +89,7 @@ class PredibaseChatCompletion(BaseLLM):
 def process_response( # noqa: PLR0915
 self,
 model: str,
-response: Union[requests.Response, httpx.Response],
+response: httpx.Response,
 model_response: ModelResponse,
 stream: bool,
 logging_obj: LiteLLMLoggingBaseClass,
@@ -347,7 +346,7 @@ class PredibaseChatCompletion(BaseLLM):
 ### SYNC STREAMING
 if stream is True:
-response = requests.post(
+response = litellm.module_level_client.post(
 completion_url,
 headers=headers,
 data=json.dumps(data),
@@ -363,7 +362,7 @@ class PredibaseChatCompletion(BaseLLM):
 return _response
 ### SYNC COMPLETION
 else:
-response = requests.post(
+response = litellm.module_level_client.post(
 url=completion_url,
 headers=headers,
 data=json.dumps(data),

View file

@@ -10,12 +10,16 @@ from enum import Enum
 from functools import partial
 from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
-import httpx # type: ignore
-import requests # type: ignore
+import httpx
 import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.asyncify import asyncify
+from litellm.litellm_core_utils.prompt_templates.factory import (
+custom_prompt,
+prompt_factory,
+)
+from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
 from litellm.llms.custom_httpx.http_handler import (
 AsyncHTTPHandler,
 HTTPHandler,
@@ -31,8 +35,6 @@ from litellm.utils import (
 get_secret,
 )
-from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
 from ..common_utils import AWSEventStreamDecoder, SagemakerError
 from .transformation import SagemakerConfig

View file

@@ -24,23 +24,22 @@ from typing import (
 )
 import httpx # type: ignore
-import requests # type: ignore
 import litellm
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.litellm_core_utils.prompt_templates.factory import (
+convert_generic_image_chunk_to_openai_image_obj,
+convert_to_anthropic_image_obj,
+)
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
 AsyncHTTPHandler,
 HTTPHandler,
 get_async_httpx_client,
 )
-from litellm.litellm_core_utils.prompt_templates.factory import (
-convert_generic_image_chunk_to_openai_image_obj,
-convert_to_anthropic_image_obj,
-)
 from litellm.types.llms.openai import (
 AllMessageValues,
 ChatCompletionResponseMessage,

View file

@@ -7,19 +7,18 @@ import uuid
 from enum import Enum
 from typing import Any, Callable, List, Literal, Optional, Union, cast
-import httpx # type: ignore
-import requests # type: ignore
+import httpx
 from pydantic import BaseModel
 import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.litellm_core_utils.prompt_templates.factory import (
 convert_to_anthropic_image_obj,
 convert_to_gemini_tool_call_invoke,
 convert_to_gemini_tool_call_result,
 )
+from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.types.files import (
 get_file_mime_type_for_file_type,
 get_file_type_from_extension,

View file

@@ -9,11 +9,19 @@ import uuid
 from enum import Enum
 from typing import Any, Callable, List, Optional, Tuple, Union
-import httpx # type: ignore
-import requests # type: ignore
+import httpx
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.litellm_core_utils.prompt_templates.factory import (
+construct_tool_use_system_prompt,
+contains_tag,
+custom_prompt,
+extract_between_tags,
+parse_xml_params,
+prompt_factory,
+response_schema_prompt,
+)
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.types.llms.openai import (
 AllMessageValues,
@@ -24,15 +32,6 @@ from litellm.types.utils import ResponseFormatChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 from ....anthropic.chat.transformation import AnthropicConfig
-from litellm.litellm_core_utils.prompt_templates.factory import (
-construct_tool_use_system_prompt,
-contains_tag,
-custom_prompt,
-extract_between_tags,
-parse_xml_params,
-prompt_factory,
-response_schema_prompt,
-)
 class VertexAIError(Exception):

View file

@@ -5,12 +5,13 @@ from enum import Enum
 from typing import Any, Callable
 import httpx
-import requests # type: ignore
+from litellm.litellm_core_utils.prompt_templates.factory import (
+custom_prompt,
+prompt_factory,
+)
 from litellm.utils import ModelResponse, Usage
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
 llm = None

View file

@@ -2645,45 +2645,24 @@ def completion( # type: ignore # noqa: PLR0915
 or get_secret("OLLAMA_API_BASE")
 or "http://localhost:11434"
 )
-custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-if model in custom_prompt_dict:
-# check if the model has a registered custom prompt
-model_prompt_details = custom_prompt_dict[model]
-ollama_prompt = custom_prompt(
-role_dict=model_prompt_details["roles"],
-initial_prompt_value=model_prompt_details["initial_prompt_value"],
-final_prompt_value=model_prompt_details["final_prompt_value"],
-messages=messages,
-)
-else:
-modified_prompt = ollama_pt(model=model, messages=messages)
-if isinstance(modified_prompt, dict):
-# for multimode models - ollama/llava prompt_factory returns a dict {
-# "prompt": prompt,
-# "images": images
-# }
-ollama_prompt, images = (
-modified_prompt["prompt"],
-modified_prompt["images"],
-)
-optional_params["images"] = images
-else:
-ollama_prompt = modified_prompt
-## LOGGING
-generator = ollama.get_ollama_response(
-api_base=api_base,
+response = base_llm_http_handler.completion(
 model=model,
-prompt=ollama_prompt,
+stream=stream,
-optional_params=optional_params,
+messages=messages,
-logging_obj=logging,
 acompletion=acompletion,
+api_base=api_base,
 model_response=model_response,
+optional_params=optional_params,
+litellm_params=litellm_params,
+custom_llm_provider="ollama",
+timeout=timeout,
+headers=headers,
 encoding=encoding,
+api_key=api_key,
+logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
+client=client,
 )
-if acompletion is True or optional_params.get("stream", False) is True:
-return generator
-response = generator
 elif custom_llm_provider == "ollama_chat":
 api_base = (
 litellm.api_base
@@ -2833,8 +2812,6 @@ def completion( # type: ignore # noqa: PLR0915
 return response
 response = model_response
 elif custom_llm_provider == "custom":
-import requests
 url = litellm.api_base or api_base or ""
 if url is None or url == "":
 raise ValueError(
@@ -2843,7 +2820,7 @@ def completion( # type: ignore # noqa: PLR0915
 """
 assume input to custom LLM api bases follow this format:
-resp = requests.post(
+resp = litellm.module_level_client.post(
 api_base,
 json={
 'model': 'meta-llama/Llama-2-13b-hf', # model name
@@ -2859,7 +2836,7 @@ def completion( # type: ignore # noqa: PLR0915
 """
 prompt = " ".join([message["content"] for message in messages]) # type: ignore
-resp = requests.post(
+resp = litellm.module_level_client.post(
 url,
 json={
 "model": model,
@@ -2871,7 +2848,6 @@ def completion( # type: ignore # noqa: PLR0915
 "top_k": kwargs.get("top_k", 40),
 },
 },
-verify=litellm.ssl_verify,
 )
 response_json = resp.json()
 """

View file

@@ -303,7 +303,7 @@ def run_server( # noqa: PLR0915
 return
 if model and "ollama" in model and api_base is None:
 run_ollama_serve()
-import requests
+import httpx
 if test_async is True:
 import concurrent
@@ -319,7 +319,7 @@ def run_server( # noqa: PLR0915
 ],
 }
-response = requests.post("http://0.0.0.0:4000/queue/request", json=data)
+response = httpx.post("http://0.0.0.0:4000/queue/request", json=data)
 response = response.json()
@@ -327,7 +327,7 @@ def run_server( # noqa: PLR0915
 try:
 url = response["url"]
 polling_url = f"{api_base}{url}"
-polling_response = requests.get(polling_url)
+polling_response = httpx.get(polling_url)
 polling_response = polling_response.json()
 print("\n RESPONSE FROM POLLING JOB", polling_response) # noqa
 status = polling_response["status"]
@@ -378,7 +378,7 @@ def run_server( # noqa: PLR0915
 if health is not False:
 print("\nLiteLLM: Health Testing models in config") # noqa
-response = requests.get(url=f"http://{host}:{port}/health")
+response = httpx.get(url=f"http://{host}:{port}/health")
 print(json.dumps(response.json(), indent=4)) # noqa
 return
 if test is not False:

View file

@ -11,9 +11,6 @@ import random
import traceback import traceback
from typing import Optional from typing import Optional
import dotenv # type: ignore
import requests
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger

View file

@ -6,10 +6,6 @@ import traceback
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import dotenv
import requests
from pydantic import BaseModel
from litellm import token_counter from litellm import token_counter
from litellm._logging import verbose_router_logger from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache

View file

@ -43,7 +43,6 @@ import aiohttp
import dotenv import dotenv
import httpx import httpx
import openai import openai
import requests
import tiktoken import tiktoken
from httpx import Proxy from httpx import Proxy
from httpx._utils import get_environment_proxies from httpx._utils import get_environment_proxies
@ -4175,7 +4174,7 @@ def get_max_tokens(model: str) -> Optional[int]:
config_url = f"https://huggingface.co/{model_name}/raw/main/config.json" config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
try: try:
# Make the HTTP request to get the raw JSON file # Make the HTTP request to get the raw JSON file
response = requests.get(config_url) response = litellm.module_level_client.get(config_url)
response.raise_for_status() # Raise an exception for bad responses (4xx or 5xx) response.raise_for_status() # Raise an exception for bad responses (4xx or 5xx)
# Parse the JSON response # Parse the JSON response
@ -4186,7 +4185,7 @@ def get_max_tokens(model: str) -> Optional[int]:
return max_position_embeddings return max_position_embeddings
else: else:
return None return None
except requests.exceptions.RequestException: except Exception:
return None return None
try: try:
@ -4361,7 +4360,7 @@ def get_model_info( # noqa: PLR0915
try: try:
# Make the HTTP request to get the raw JSON file # Make the HTTP request to get the raw JSON file
response = requests.get(config_url) response = litellm.module_level_client.get(config_url)
response.raise_for_status() # Raise an exception for bad responses (4xx or 5xx) response.raise_for_status() # Raise an exception for bad responses (4xx or 5xx)
# Parse the JSON response # Parse the JSON response
@ -4374,7 +4373,7 @@ def get_model_info( # noqa: PLR0915
return max_position_embeddings return max_position_embeddings
else: else:
return None return None
except requests.exceptions.RequestException: except Exception:
return None return None
try: try:
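Outside the diff, a small sketch of the pattern these two hunks settle on: the shared client's httpx-style response still provides raise_for_status() and json(), and the broad except Exception replaces requests.exceptions.RequestException. The model id below is only illustrative:

import litellm

model_name = "mistralai/Mistral-7B-v0.1"  # illustrative HF model id
config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
try:
    # Fetch the raw config.json through the shared module-level client
    response = litellm.module_level_client.get(config_url)
    response.raise_for_status()  # raise on 4xx/5xx responses
    max_position_embeddings = response.json().get("max_position_embeddings")
except Exception:
    max_position_embeddings = None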

View file

@ -0,0 +1,183 @@
"""
Prevent usage of 'requests' library in the codebase.
"""
import os
import ast
import sys
from typing import List, Tuple
def find_requests_usage(directory: str) -> List[Tuple[str, int, str]]:
"""
Recursively search for Python files in the given directory
and find usages of the 'requests' library.
Args:
directory (str): The root directory to search for Python files
Returns:
List of tuples containing (file_path, line_number, usage_type)
"""
requests_usages = []
def is_likely_requests_usage(node):
"""
More precise check to avoid false positives
"""
try:
# Convert node to string representation
node_str = ast.unparse(node)
# Specific checks to ensure it's the requests library
requests_identifiers = [
# HTTP methods
"requests.get",
"requests.post",
"requests.put",
"requests.delete",
"requests.head",
"requests.patch",
"requests.options",
"requests.request",
"requests.session",
# Types and exceptions
"requests.Response",
"requests.Request",
"requests.Session",
"requests.ConnectionError",
"requests.HTTPError",
"requests.Timeout",
"requests.TooManyRedirects",
"requests.RequestException",
# Additional modules and attributes
"requests.api",
"requests.exceptions",
"requests.models",
"requests.auth",
"requests.cookies",
"requests.structures",
]
# Check for specific requests library identifiers
return any(identifier in node_str for identifier in requests_identifiers)
except Exception:
return False
def scan_file(file_path: str):
"""
Scan a single Python file for requests library usage
"""
try:
# Use utf-8-sig to handle files with BOM, ignore errors
with open(file_path, "r", encoding="utf-8-sig", errors="ignore") as file:
tree = ast.parse(file.read())
for node in ast.walk(tree):
# Check import statements
if isinstance(node, ast.Import):
for alias in node.names:
if alias.name == "requests":
requests_usages.append(
(file_path, node.lineno, f"Import: {alias.name}")
)
# Check import from statements
elif isinstance(node, ast.ImportFrom):
if node.module == "requests":
requests_usages.append(
(file_path, node.lineno, f"Import from: {node.module}")
)
# Check method calls
elif isinstance(node, ast.Call):
# More precise check for requests usage
try:
if is_likely_requests_usage(node.func):
requests_usages.append(
(
file_path,
node.lineno,
f"Method Call: {ast.unparse(node.func)}",
)
)
except Exception:
pass
# Check attribute access
elif isinstance(node, ast.Attribute):
try:
# More precise check
if is_likely_requests_usage(node):
requests_usages.append(
(
file_path,
node.lineno,
f"Attribute Access: {ast.unparse(node)}",
)
)
except Exception:
pass
except SyntaxError as e:
print(f"Syntax error in {file_path}: {e}", file=sys.stderr)
except Exception as e:
print(f"Error processing {file_path}: {e}", file=sys.stderr)
# Recursively walk through directory
for root, dirs, files in os.walk(directory):
# Remove virtual environment and cache directories from search
dirs[:] = [
d
for d in dirs
if not any(
venv in d
for venv in [
"venv",
"env",
"myenv",
".venv",
"__pycache__",
".pytest_cache",
]
)
]
for file in files:
if file.endswith(".py"):
full_path = os.path.join(root, file)
# Skip files in virtual environment or cache directories
if not any(
venv in full_path
for venv in [
"venv",
"env",
"myenv",
".venv",
"__pycache__",
".pytest_cache",
]
):
scan_file(full_path)
return requests_usages
def main():
# Directory to scan: the litellm package, relative to this script's location
directory = "../../litellm"
# Find requests library usages
results = find_requests_usage(directory)
# Print results
if results:
print("Requests Library Usages Found:")
for file_path, line_num, usage_type in results:
print(f"{file_path}:{line_num} - {usage_type}")
else:
print("No requests library usages found.")
if __name__ == "__main__":
main()
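If the checker needs to gate CI rather than just print findings, a small wrapper along these lines would work; it is not part of the committed script and assumes it runs in the same module as find_requests_usage:

import sys

def enforce_no_requests(directory: str = "../../litellm") -> None:
    # Exit non-zero so the CI job fails when any 'requests' usage is found.
    violations = find_requests_usage(directory)
    for file_path, line_num, usage_type in violations:
        print(f"{file_path}:{line_num} - {usage_type}", file=sys.stderr)
    if violations:
        sys.exit(1)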

View file

@ -1940,10 +1940,11 @@ def test_ollama_image():
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status_code = 200 mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"} mock_response.headers = {"Content-Type": "application/json"}
data_json = json.loads(kwargs["data"])
mock_response.json.return_value = { mock_response.json.return_value = {
# return the image in the response so that it can be tested # return the image in the response so that it can be tested
# against the original # against the original
"response": kwargs["json"]["images"] "response": data_json["images"]
} }
return mock_response return mock_response
@ -1971,9 +1972,10 @@ def test_ollama_image():
[datauri_base64_data, datauri_base64_data], [datauri_base64_data, datauri_base64_data],
] ]
client = HTTPHandler()
for test in tests: for test in tests:
try: try:
with patch("requests.post", side_effect=mock_post): with patch.object(client, "post", side_effect=mock_post):
response = completion( response = completion(
model="ollama/llava", model="ollama/llava",
messages=[ messages=[
@ -1988,6 +1990,7 @@ def test_ollama_image():
], ],
} }
], ],
client=client,
) )
if not test[1]: if not test[1]:
# the conversion process may not always generate the same image, # the conversion process may not always generate the same image,
@ -2387,8 +2390,8 @@ def test_completion_ollama_hosted():
response = completion( response = completion(
model="ollama/phi", model="ollama/phi",
messages=messages, messages=messages,
max_tokens=2, max_tokens=20,
api_base="https://test-ollama-endpoint.onrender.com", # api_base="https://test-ollama-endpoint.onrender.com",
) )
# Add any assertions here to check the response # Add any assertions here to check the response
print(response) print(response)
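The updated test_ollama_image above also shows the mocking pattern that replaces patch("requests.post"): construct an HTTPHandler, pass it to completion() via the client kwarg, and patch that instance's post method. A condensed sketch of the pattern follows; the mocked payload and message are illustrative, and the HTTPHandler import path is assumed:

from unittest.mock import MagicMock, patch

from litellm import completion
from litellm.llms.custom_httpx.http_handler import HTTPHandler  # assumed import path

client = HTTPHandler()

def mock_post(url, **kwargs):
    # Fake a successful Ollama JSON response without any network call.
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = {"response": "ok"}  # illustrative payload
    return mock_response

with patch.object(client, "post", side_effect=mock_post):
    response = completion(
        model="ollama/llava",
        messages=[{"role": "user", "content": "hi"}],
        client=client,  # the injected client is what the test patches
    )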

View file

@ -606,14 +606,14 @@ def test_completion_azure_function_calling_stream():
@pytest.mark.skip("Flaky ollama test - needs to be fixed") @pytest.mark.skip("Flaky ollama test - needs to be fixed")
def test_completion_ollama_hosted_stream(): def test_completion_ollama_hosted_stream():
try: try:
litellm.set_verbose = True # litellm.set_verbose = True
response = completion( response = completion(
model="ollama/phi", model="ollama/phi",
messages=messages, messages=messages,
max_tokens=10, max_tokens=100,
num_retries=3, num_retries=3,
timeout=20, timeout=20,
api_base="https://test-ollama-endpoint.onrender.com", # api_base="https://test-ollama-endpoint.onrender.com",
stream=True, stream=True,
) )
# Add any assertions here to check the response # Add any assertions here to check the response