Merge pull request #4796 from BerriAI/litellm_refactor_requests_factory

fix(factory.py): refactor factory to use httpx client
This commit is contained in:
Krish Dholakia 2024-07-19 21:07:41 -07:00 committed by GitHub
commit 156d445597
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -6,7 +6,6 @@ import xml.etree.ElementTree as ET
from enum import Enum from enum import Enum
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
import requests
from jinja2 import BaseLoader, Template, exceptions, meta from jinja2 import BaseLoader, Template, exceptions, meta
from jinja2.sandbox import ImmutableSandboxedEnvironment from jinja2.sandbox import ImmutableSandboxedEnvironment
@ -14,6 +13,7 @@ import litellm
import litellm.types import litellm.types
import litellm.types.llms import litellm.types.llms
import litellm.types.llms.vertex_ai import litellm.types.llms.vertex_ai
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.types.completion import ( from litellm.types.completion import (
ChatCompletionFunctionMessageParam, ChatCompletionFunctionMessageParam,
ChatCompletionMessageParam, ChatCompletionMessageParam,
@ -364,7 +364,8 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json" f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
) )
# Make a GET request to fetch the JSON data # Make a GET request to fetch the JSON data
response = requests.get(url) client = HTTPHandler(concurrent_limit=1)
response = client.get(url)
if response.status_code == 200: if response.status_code == 200:
# Parse the JSON data # Parse the JSON data
tokenizer_config = json.loads(response.content) tokenizer_config = json.loads(response.content)
@ -494,7 +495,8 @@ def claude_2_1_pt(
def get_model_info(token, model): def get_model_info(token, model):
try: try:
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
response = requests.get("https://api.together.xyz/models/info", headers=headers) client = HTTPHandler(concurrent_limit=1)
response = client.get("https://api.together.xyz/models/info", headers=headers)
if response.status_code == 200: if response.status_code == 200:
model_info = response.json() model_info = response.json()
for m in model_info: for m in model_info:
@ -657,11 +659,11 @@ def construct_tool_use_system_prompt(
def convert_url_to_base64(url): def convert_url_to_base64(url):
import base64 import base64
import requests client = HTTPHandler(concurrent_limit=1)
for _ in range(3): for _ in range(3):
try: try:
response = requests.get(url)
response = client.get(url)
break break
except: except:
pass pass
@ -1798,7 +1800,8 @@ def _load_image_from_url(image_url):
try: try:
# Send a GET request to the image URL # Send a GET request to the image URL
response = requests.get(image_url) client = HTTPHandler(concurrent_limit=1)
response = client.get(image_url)
response.raise_for_status() # Raise an exception for HTTP errors response.raise_for_status() # Raise an exception for HTTP errors
# Check the response's content type to ensure it is an image # Check the response's content type to ensure it is an image
@ -1811,8 +1814,6 @@ def _load_image_from_url(image_url):
# Load the image from the response content # Load the image from the response content
return Image.open(BytesIO(response.content)) return Image.open(BytesIO(response.content))
except requests.RequestException as e:
raise Exception(f"Request failed: {e}")
except Exception as e: except Exception as e:
raise e raise e
@ -1989,8 +1990,9 @@ def get_image_details(image_url) -> Tuple[str, str]:
try: try:
import base64 import base64
client = HTTPHandler(concurrent_limit=1)
# Send a GET request to the image URL # Send a GET request to the image URL
response = requests.get(image_url) response = client.get(image_url)
response.raise_for_status() # Raise an exception for HTTP errors response.raise_for_status() # Raise an exception for HTTP errors
# Check the response's content type to ensure it is an image # Check the response's content type to ensure it is an image
@ -2010,8 +2012,6 @@ def get_image_details(image_url) -> Tuple[str, str]:
return base64_bytes, mime_type return base64_bytes, mime_type
except requests.RequestException as e:
raise Exception(f"Request failed: {e}")
except Exception as e: except Exception as e:
raise e raise e