fix REPLICATE

This commit is contained in:
Ishaan Jaff 2024-11-21 09:42:01 -08:00
parent fdaee84b82
commit 3d3d651b89

View file

@ -9,7 +9,10 @@ import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .prompt_templates.factory import custom_prompt, prompt_factory
@ -325,7 +328,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
async def async_handle_prediction_response_streaming(
prediction_url, api_token, print_verbose
):
http_handler = AsyncHTTPHandler(concurrent_limit=1)
http_handler = get_async_httpx_client(llm_provider=litellm.LlmProviders.REPLICATE)
previous_output = ""
output_string = ""
@ -560,7 +563,9 @@ async def async_completion(
logging_obj,
print_verbose,
) -> Union[ModelResponse, CustomStreamWrapper]:
http_handler = AsyncHTTPHandler(concurrent_limit=1)
http_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.REPLICATE,
)
prediction_url = await async_start_prediction(
version_id,
input_data,