fix(factory.py): refactor factory to use httpx client

This commit is contained in:
Krrish Dholakia 2024-07-19 15:35:05 -07:00
parent 4c4f032a75
commit 757dedd4c8

View file

@ -6,7 +6,6 @@ import xml.etree.ElementTree as ET
from enum import Enum from enum import Enum
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
import requests
from jinja2 import BaseLoader, Template, exceptions, meta from jinja2 import BaseLoader, Template, exceptions, meta
from jinja2.sandbox import ImmutableSandboxedEnvironment from jinja2.sandbox import ImmutableSandboxedEnvironment
@ -365,7 +364,8 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json" f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
) )
# Make a GET request to fetch the JSON data # Make a GET request to fetch the JSON data
response = requests.get(url) client = HTTPHandler(concurrent_limit=1)
response = client.get(url)
if response.status_code == 200: if response.status_code == 200:
# Parse the JSON data # Parse the JSON data
tokenizer_config = json.loads(response.content) tokenizer_config = json.loads(response.content)
@ -495,7 +495,8 @@ def claude_2_1_pt(
def get_model_info(token, model): def get_model_info(token, model):
try: try:
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
response = requests.get("https://api.together.xyz/models/info", headers=headers) client = HTTPHandler(concurrent_limit=1)
response = client.get("https://api.together.xyz/models/info", headers=headers)
if response.status_code == 200: if response.status_code == 200:
model_info = response.json() model_info = response.json()
for m in model_info: for m in model_info:
@ -658,11 +659,11 @@ def construct_tool_use_system_prompt(
def convert_url_to_base64(url): def convert_url_to_base64(url):
import base64 import base64
import requests client = HTTPHandler(concurrent_limit=1)
for _ in range(3): for _ in range(3):
try: try:
response = requests.get(url)
response = client.get(url)
break break
except: except:
pass pass
@ -1799,7 +1800,8 @@ def _load_image_from_url(image_url):
try: try:
# Send a GET request to the image URL # Send a GET request to the image URL
response = requests.get(image_url) client = HTTPHandler(concurrent_limit=1)
response = client.get(image_url)
response.raise_for_status() # Raise an exception for HTTP errors response.raise_for_status() # Raise an exception for HTTP errors
# Check the response's content type to ensure it is an image # Check the response's content type to ensure it is an image
@ -1812,8 +1814,6 @@ def _load_image_from_url(image_url):
# Load the image from the response content # Load the image from the response content
return Image.open(BytesIO(response.content)) return Image.open(BytesIO(response.content))
except requests.RequestException as e:
raise Exception(f"Request failed: {e}")
except Exception as e: except Exception as e:
raise e raise e
@ -2012,8 +2012,6 @@ def get_image_details(image_url) -> Tuple[str, str]:
return base64_bytes, mime_type return base64_bytes, mime_type
except requests.RequestException as e:
raise Exception(f"Request failed: {e}")
except Exception as e: except Exception as e:
raise e raise e