From aba600a892922ae99290d09e20c7a78962a6371d Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 23 Jul 2024 11:03:34 -0700 Subject: [PATCH] fix triton linting --- litellm/llms/triton.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/litellm/llms/triton.py b/litellm/llms/triton.py index 770898949..7d0338d06 100644 --- a/litellm/llms/triton.py +++ b/litellm/llms/triton.py @@ -1,24 +1,27 @@ -import os import json -from enum import Enum -import requests # type: ignore +import os import time -from typing import Callable, Optional, List, Sequence, Any, Union, Dict +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +import httpx # type: ignore +import requests # type: ignore + +import litellm +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.utils import ( - ModelResponse, Choices, + CustomStreamWrapper, Delta, + EmbeddingResponse, + Message, + ModelResponse, Usage, map_finish_reason, - CustomStreamWrapper, - Message, - EmbeddingResponse, ) -import litellm -from .prompt_templates.factory import prompt_factory, custom_prompt -from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler + from .base import BaseLLM -import httpx # type: ignore +from .prompt_templates.factory import custom_prompt, prompt_factory class TritonError(Exception): @@ -143,7 +146,7 @@ class TritonChatCompletion(BaseLLM): logging_obj=None, optional_params=None, client=None, - stream: bool = False, + stream: Optional[bool] = False, acompletion: bool = False, ) -> ModelResponse: type_of_model = "" @@ -220,12 +223,12 @@ class TritonChatCompletion(BaseLLM): ) headers = {"Content-Type": "application/json"} - data_for_triton = json.dumps(data_for_triton) + json_data_for_triton: str = json.dumps(data_for_triton) if acompletion: - return self.acompletion( + return self.acompletion( # type: ignore model, - data_for_triton, + json_data_for_triton, headers=headers, logging_obj=logging_obj, api_base=api_base,