diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e8bb1ff66..cc41d85f1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,11 +16,11 @@ repos: name: Check if files match entry: python3 ci_cd/check_files_match.py language: system -# - repo: local -# hooks: -# - id: mypy -# name: mypy -# entry: python3 -m mypy --ignore-missing-imports -# language: system -# types: [python] -# files: ^litellm/ \ No newline at end of file +- repo: local + hooks: + - id: mypy + name: mypy + entry: python3 -m mypy --ignore-missing-imports + language: system + types: [python] + files: ^litellm/ \ No newline at end of file diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py index b250f3013..a2c4457c2 100644 --- a/litellm/llms/huggingface_restapi.py +++ b/litellm/llms/huggingface_restapi.py @@ -399,10 +399,11 @@ class Huggingface(BaseLLM): data = { "inputs": prompt, "parameters": optional_params, - "stream": ( + "stream": ( # type: ignore True if "stream" in optional_params - and optional_params["stream"] == True + and isinstance(optional_params["stream"], bool) + and optional_params["stream"] == True # type: ignore else False ), } @@ -433,7 +434,7 @@ class Huggingface(BaseLLM): data = { "inputs": prompt, "parameters": inference_params, - "stream": ( + "stream": ( # type: ignore True if "stream" in optional_params and optional_params["stream"] == True diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py index f3935984d..ef9c6b0ba 100644 --- a/litellm/llms/predibase.py +++ b/litellm/llms/predibase.py @@ -129,7 +129,7 @@ class PredibaseChatCompletion(BaseLLM): ) super().__init__() - def validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict: + def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict: if api_key is None: raise ValueError( "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params" @@ -309,7 +309,7 @@ class PredibaseChatCompletion(BaseLLM): logger_fn=None, headers: dict = {}, ) -> Union[ModelResponse, CustomStreamWrapper]: - headers = self.validate_environment(api_key, headers) + headers = self._validate_environment(api_key, headers) completion_url = "" input_text = "" base_url = "https://serving.app.predibase.com" @@ -411,13 +411,13 @@ class PredibaseChatCompletion(BaseLLM): data=json.dumps(data), stream=stream, ) - response = CustomStreamWrapper( + _response = CustomStreamWrapper( response.iter_lines(), model, custom_llm_provider="predibase", logging_obj=logging_obj, ) - return response + return _response ### SYNC COMPLETION else: response = requests.post(