fix triton linting

This commit is contained in:
Ishaan Jaff 2024-07-23 11:03:34 -07:00
parent a8c88dad64
commit aba600a892

View file

@ -1,24 +1,27 @@
import os
import json import json
from enum import Enum import os
import requests # type: ignore
import time import time
from typing import Callable, Optional, List, Sequence, Any, Union, Dict from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import ( from litellm.utils import (
ModelResponse,
Choices, Choices,
CustomStreamWrapper,
Delta, Delta,
EmbeddingResponse,
Message,
ModelResponse,
Usage, Usage,
map_finish_reason, map_finish_reason,
CustomStreamWrapper,
Message,
EmbeddingResponse,
) )
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM from .base import BaseLLM
import httpx # type: ignore from .prompt_templates.factory import custom_prompt, prompt_factory
class TritonError(Exception): class TritonError(Exception):
@ -143,7 +146,7 @@ class TritonChatCompletion(BaseLLM):
logging_obj=None, logging_obj=None,
optional_params=None, optional_params=None,
client=None, client=None,
stream: bool = False, stream: Optional[bool] = False,
acompletion: bool = False, acompletion: bool = False,
) -> ModelResponse: ) -> ModelResponse:
type_of_model = "" type_of_model = ""
@ -220,12 +223,12 @@ class TritonChatCompletion(BaseLLM):
) )
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
data_for_triton = json.dumps(data_for_triton) json_data_for_triton: str = json.dumps(data_for_triton)
if acompletion: if acompletion:
return self.acompletion( return self.acompletion( # type: ignore
model, model,
data_for_triton, json_data_for_triton,
headers=headers, headers=headers,
logging_obj=logging_obj, logging_obj=logging_obj,
api_base=api_base, api_base=api_base,