Merge pull request #4796 from BerriAI/litellm_refactor_requests_factory

fix(factory.py): refactor factory to use httpx client
This commit is contained in:
Krish Dholakia 2024-07-19 21:07:41 -07:00 committed by GitHub
commit 156d445597
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -6,7 +6,6 @@ import xml.etree.ElementTree as ET
from enum import Enum
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
import requests
from jinja2 import BaseLoader, Template, exceptions, meta
from jinja2.sandbox import ImmutableSandboxedEnvironment
@ -14,6 +13,7 @@ import litellm
import litellm.types
import litellm.types.llms
import litellm.types.llms.vertex_ai
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.types.completion import (
ChatCompletionFunctionMessageParam,
ChatCompletionMessageParam,
@ -364,7 +364,8 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
)
# Make a GET request to fetch the JSON data
response = requests.get(url)
client = HTTPHandler(concurrent_limit=1)
response = client.get(url)
if response.status_code == 200:
# Parse the JSON data
tokenizer_config = json.loads(response.content)
@ -494,7 +495,8 @@ def claude_2_1_pt(
def get_model_info(token, model):
try:
headers = {"Authorization": f"Bearer {token}"}
response = requests.get("https://api.together.xyz/models/info", headers=headers)
client = HTTPHandler(concurrent_limit=1)
response = client.get("https://api.together.xyz/models/info", headers=headers)
if response.status_code == 200:
model_info = response.json()
for m in model_info:
@ -657,11 +659,11 @@ def construct_tool_use_system_prompt(
def convert_url_to_base64(url):
import base64
import requests
client = HTTPHandler(concurrent_limit=1)
for _ in range(3):
try:
response = requests.get(url)
response = client.get(url)
break
except:
pass
@ -1798,7 +1800,8 @@ def _load_image_from_url(image_url):
try:
# Send a GET request to the image URL
response = requests.get(image_url)
client = HTTPHandler(concurrent_limit=1)
response = client.get(image_url)
response.raise_for_status() # Raise an exception for HTTP errors
# Check the response's content type to ensure it is an image
@ -1811,8 +1814,6 @@ def _load_image_from_url(image_url):
# Load the image from the response content
return Image.open(BytesIO(response.content))
except requests.RequestException as e:
raise Exception(f"Request failed: {e}")
except Exception as e:
raise e
@ -1989,8 +1990,9 @@ def get_image_details(image_url) -> Tuple[str, str]:
try:
import base64
client = HTTPHandler(concurrent_limit=1)
# Send a GET request to the image URL
response = requests.get(image_url)
response = client.get(image_url)
response.raise_for_status() # Raise an exception for HTTP errors
# Check the response's content type to ensure it is an image
@ -2010,8 +2012,6 @@ def get_image_details(image_url) -> Tuple[str, str]:
return base64_bytes, mime_type
except requests.RequestException as e:
raise Exception(f"Request failed: {e}")
except Exception as e:
raise e