build(pyproject.toml): add new dev dependencies - for type checking (#9631)

* build(pyproject.toml): add new dev dependencies - for type checking

* build: reformat files to fit black

* ci: reformat to fit black

* ci(test-litellm.yml): make test runs clearer

* build(pyproject.toml): add ruff

* fix: fix ruff checks

* build(mypy/): fix mypy linting errors

* fix(hashicorp_secret_manager.py): fix passing cert for tls auth

* build(mypy/): resolve all mypy errors

* test: update test

* fix: fix black formatting

* build(pre-commit-config.yaml): use poetry run black

* fix(proxy_server.py): fix linting error

* fix: fix ruff unsafe-representation error (use `!r` for safe repr)

This commit is contained in:
Krish Dholakia 2025-03-29 11:02:13 -07:00 committed by GitHub
parent 95e5dfae5a
commit 9b7ebb6a7d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
214 changed files with 1553 additions and 1433 deletions

View file

@ -1,6 +1,6 @@
import json
from copy import deepcopy
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Union, cast
import httpx
@ -35,7 +35,6 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
# set os.environ['AWS_REGION_NAME'] = <your-region_name>
class SagemakerLLM(BaseAWSLLM):
def _load_credentials(
self,
optional_params: dict,
@ -154,7 +153,6 @@ class SagemakerLLM(BaseAWSLLM):
acompletion: bool = False,
headers: dict = {},
):
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
credentials, aws_region_name = self._load_credentials(optional_params)
inference_params = deepcopy(optional_params)
@ -437,10 +435,14 @@ class SagemakerLLM(BaseAWSLLM):
prepared_request.headers.update(
{"X-Amzn-SageMaker-Inference-Component": model_id}
)
if not prepared_request.body:
raise ValueError("Prepared request body is empty")
completion_stream = await self.make_async_call(
api_base=prepared_request.url,
headers=prepared_request.headers, # type: ignore
data=prepared_request.body,
data=cast(str, prepared_request.body),
logging_obj=logging_obj,
)
streaming_response = CustomStreamWrapper(
@ -625,7 +627,7 @@ class SagemakerLLM(BaseAWSLLM):
response = client.invoke_endpoint(
EndpointName={model},
ContentType="application/json",
Body={data}, # type: ignore
Body=f"{data!r}", # Use !r for safe representation
CustomAttributes="accept_eula=true",
)""" # type: ignore
logging_obj.pre_call(