fix(predibase.py): fix event loop closed error

This commit is contained in:
Krrish Dholakia 2024-05-09 19:07:19 -07:00
parent 491e177348
commit 76d4290591

View file

@ -124,9 +124,6 @@ class PredibaseConfig:
class PredibaseChatCompletion(BaseLLM): class PredibaseChatCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=litellm.request_timeout, connect=5.0)
)
super().__init__() super().__init__()
def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict: def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
@ -457,8 +454,10 @@ class PredibaseChatCompletion(BaseLLM):
logger_fn=None, logger_fn=None,
headers={}, headers={},
) -> ModelResponse: ) -> ModelResponse:
async_handler = AsyncHTTPHandler(
response = await self.async_handler.post( timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data) api_base, headers=headers, data=json.dumps(data)
) )
return self.process_response( return self.process_response(
@ -491,9 +490,11 @@ class PredibaseChatCompletion(BaseLLM):
logger_fn=None, logger_fn=None,
headers={}, headers={},
) -> CustomStreamWrapper: ) -> CustomStreamWrapper:
async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
data["stream"] = True data["stream"] = True
response = await self.async_handler.post( response = await async_handler.post(
url="https://serving.app.predibase.com/c4768f95/deployments/v2/llms/llama-3-8b-instruct/generate_stream", url="https://serving.app.predibase.com/c4768f95/deployments/v2/llms/llama-3-8b-instruct/generate_stream",
headers=headers, headers=headers,
data=json.dumps(data), data=json.dumps(data),