fix REPLICATE

This commit is contained in:
Ishaan Jaff 2024-11-21 09:42:01 -08:00
parent fdaee84b82
commit 3d3d651b89

View file

@@ -9,7 +9,10 @@ import httpx  # type: ignore
 import requests  # type: ignore
 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 from .prompt_templates.factory import custom_prompt, prompt_factory
@@ -325,7 +328,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
 async def async_handle_prediction_response_streaming(
     prediction_url, api_token, print_verbose
 ):
-    http_handler = AsyncHTTPHandler(concurrent_limit=1)
+    http_handler = get_async_httpx_client(llm_provider=litellm.LlmProviders.REPLICATE)
     previous_output = ""
     output_string = ""
@@ -560,7 +563,9 @@ async def async_completion(
     logging_obj,
     print_verbose,
 ) -> Union[ModelResponse, CustomStreamWrapper]:
-    http_handler = AsyncHTTPHandler(concurrent_limit=1)
+    http_handler = get_async_httpx_client(
+        llm_provider=litellm.LlmProviders.REPLICATE,
+    )
     prediction_url = await async_start_prediction(
         version_id,
         input_data,