(Fix) - linting errors

Ishaan Jaff 2024-05-11 15:57:06 -07:00
parent b9b8bf52f3
commit 91a6a0eef4
3 changed files with 54 additions and 26 deletions

View file

@@ -202,13 +202,11 @@ class BudgetExceededError(Exception):

 ## DEPRECATED ##
 class InvalidRequestError(openai.BadRequestError):  # type: ignore
-    def __init__(
-        self, message, model, llm_provider, response: Optional[httpx.Response] = None
-    ):
+    def __init__(self, message, model, llm_provider):
         self.status_code = 400
         self.message = message
         self.model = model
         self.llm_provider = llm_provider
         super().__init__(
-            self.message, response=response, body=None
+            self.message, f"{self.model}"
         )  # Call the base class constructor with the parameters it needs
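
An observation on the right-hand side (not part of the commit): in openai>=1.x, `BadRequestError` inherits from `APIStatusError`, whose `response` and `body` parameters are keyword-only, as the keyword call on the removed lines also suggests; the new two-positional-argument `super().__init__` call would therefore likely raise a `TypeError` if this deprecated class were ever instantiated. A minimal sketch of a call that does match the v1 signature, assuming an `httpx.Response` is on hand:

    import httpx
    import openai

    # Assumption: openai>=1.x status errors take the message positionally
    # and response/body as keyword-only arguments.
    resp = httpx.Response(400, request=httpx.Request("POST", "https://example.invalid"))
    err = openai.BadRequestError("Invalid request", response=resp, body=None)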

View file

@@ -4,6 +4,7 @@ from litellm.proxy._types import UserAPIKeyAuth
 import litellm, traceback, sys, uuid
 from fastapi import HTTPException
 from litellm._logging import verbose_proxy_logger
+from typing import Optional


 class _PROXY_AzureContentSafety(
@@ -71,7 +72,7 @@ class _PROXY_AzureContentSafety(

         return result

-    async def test_violation(self, content: str, source: str = None):
+    async def test_violation(self, content: str, source: Optional[str] = None):
         verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)

         # Construct a request
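
The `test_violation` change fixes an implicit-Optional annotation: a parameter that defaults to `None` must be typed `Optional[str]` (or `str | None`), not bare `str`; mypy has rejected the bare form by default since `no_implicit_optional` became the default behavior. A minimal standalone sketch:

    from typing import Optional

    def greet(name: Optional[str] = None) -> str:
        # `name: str = None` would be flagged, since None is not a valid str.
        return f"Hello, {name or 'world'}"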

View file

@@ -35,6 +35,7 @@ from dataclasses import (
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from litellm.caching import DualCache

 oidc_cache = DualCache()

 try:
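
The added import fixes what was presumably an undefined-name lint error: the parent commit created `oidc_cache = DualCache()` at module level without importing `DualCache`. In litellm, `DualCache` layers an in-memory cache over an optional Redis cache; judging from the calls later in this diff, the relevant surface here is `get_cache(key=...)` and `set_cache(key=..., value=..., ttl=...)`. A usage sketch under that assumption:

    from litellm.caching import DualCache

    cache = DualCache()
    cache.set_cache(key="oidc/google/my-audience", value="<jwt>", ttl=3600 - 60)
    assert cache.get_cache(key="oidc/google/my-audience") == "<jwt>"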
@@ -2957,7 +2958,7 @@ def client(original_function):
                     )
                 else:
                     return result
-                return result
+            return result

     # Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print
@@ -9559,16 +9560,20 @@ def get_secret(
                 if oidc_token is not None:
                     return oidc_token

-                client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
+                oidc_client = HTTPHandler(
+                    timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+                )
                 # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
-                response = client.get(
+                response = oidc_client.get(
                     "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
                     params={"audience": oidc_aud},
                     headers={"Metadata-Flavor": "Google"},
                 )
                 if response.status_code == 200:
                     oidc_token = response.text
-                    oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
+                    oidc_cache.set_cache(
+                        key=secret_name, value=oidc_token, ttl=3600 - 60
+                    )
                     return oidc_token
                 else:
                     raise ValueError("Google OIDC provider failed")
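
Beyond the `client` to `oidc_client` rename (which stops the local variable from shadowing the `client` decorator defined earlier in this file), note the TTL: the GCE metadata server issues identity tokens valid for one hour, and caching for `3600 - 60` seconds keeps the cache from ever serving a token that is about to expire. A standalone sketch of the same request with plain httpx (only works where the metadata server is reachable, i.e. on GCE/Cloud Run):

    import httpx

    METADATA_URL = (
        "http://metadata.google.internal/computeMetadata/v1"
        "/instance/service-accounts/default/identity"
    )

    def fetch_gcp_identity_token(audience: str) -> str:
        resp = httpx.get(
            METADATA_URL,
            params={"audience": audience},
            headers={"Metadata-Flavor": "Google"},  # required by the metadata server
            timeout=5.0,
        )
        resp.raise_for_status()
        return resp.text  # the body is the raw JWT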
@@ -9587,25 +9592,34 @@ def get_secret(
             case "github":
                 # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
                 actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
-                actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
-                if actions_id_token_request_url is None or actions_id_token_request_token is None:
-                    raise ValueError("ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment")
+                actions_id_token_request_token = os.getenv(
+                    "ACTIONS_ID_TOKEN_REQUEST_TOKEN"
+                )
+                if (
+                    actions_id_token_request_url is None
+                    or actions_id_token_request_token is None
+                ):
+                    raise ValueError(
+                        "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
+                    )

                 oidc_token = oidc_cache.get_cache(key=secret_name)
                 if oidc_token is not None:
                     return oidc_token

-                client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
-                response = client.get(
+                oidc_client = HTTPHandler(
+                    timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+                )
+                response = oidc_client.get(
                     actions_id_token_request_url,
                     params={"audience": oidc_aud},
                     headers={
                         "Authorization": f"Bearer {actions_id_token_request_token}",
                         "Accept": "application/json; api-version=2.0",
                     },
                 )
                 if response.status_code == 200:
-                    oidc_token = response.text['value']
+                    oidc_token = response.text["value"]
                     oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
                     return oidc_token
                 else:
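
One bug the lint pass leaves intact on both sides: `response.text` is a `str`, so `response.text["value"]` raises `TypeError: string indices must be integers`; the runner's token endpoint returns JSON shaped like `{"value": "<jwt>"}`, so presumably `response.json()["value"]` was intended. The `300 - 5` TTL suggests the author is assuming a roughly five-minute token lifetime. A hedged sketch of the corrected extraction:

    import os
    import httpx

    def fetch_github_oidc_token(audience: str) -> str:
        url = os.environ["ACTIONS_ID_TOKEN_REQUEST_URL"]
        request_token = os.environ["ACTIONS_ID_TOKEN_REQUEST_TOKEN"]
        resp = httpx.get(
            url,
            params={"audience": audience},
            headers={
                "Authorization": f"Bearer {request_token}",
                "Accept": "application/json; api-version=2.0",
            },
        )
        resp.raise_for_status()
        # Index the parsed JSON, not the raw response text.
        return resp.json()["value"]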
@@ -9613,7 +9627,6 @@ def get_secret(
             case _:
                 raise ValueError("Unsupported OIDC provider")

-
     try:
         if litellm.secret_manager_client is not None:
             try:
@@ -10562,7 +10575,12 @@ class CustomStreamWrapper:
                 response = chunk.replace("data: ", "").strip()
                 parsed_response = json.loads(response)
             else:
-                return {"text": "", "is_finished": False, "prompt_tokens": 0, "completion_tokens": 0}
+                return {
+                    "text": "",
+                    "is_finished": False,
+                    "prompt_tokens": 0,
+                    "completion_tokens": 0,
+                }
         else:
             print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
             raise ValueError(
@@ -10583,19 +10601,32 @@ class CustomStreamWrapper:
             return {"text": "", "is_finished": False}
         except Exception as e:
             raise e

     def handle_clarifai_completion_chunk(self, chunk):
         try:
             if isinstance(chunk, dict):
                 parsed_response = chunk
             if isinstance(chunk, (str, bytes)):
                 if isinstance(chunk, bytes):
                     parsed_response = chunk.decode("utf-8")
                 else:
                     parsed_response = chunk
             data_json = json.loads(parsed_response)
-            text = data_json.get("outputs", "")[0].get("data", "").get("text", "").get("raw","")
-            prompt_tokens = len(encoding.encode(data_json.get("outputs", "")[0].get("input","").get("data", "").get("text", "").get("raw","")))
+            text = (
+                data_json.get("outputs", "")[0]
+                .get("data", "")
+                .get("text", "")
+                .get("raw", "")
+            )
+            prompt_tokens = len(
+                encoding.encode(
+                    data_json.get("outputs", "")[0]
+                    .get("input", "")
+                    .get("data", "")
+                    .get("text", "")
+                    .get("raw", "")
+                )
+            )
             completion_tokens = len(encoding.encode(text))
             return {
                 "text": text,
@@ -10650,9 +10681,7 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
-            elif (
-                self.custom_llm_provider and self.custom_llm_provider == "clarifai"
-            ):
+            elif self.custom_llm_provider and self.custom_llm_provider == "clarifai":
                 response_obj = self.handle_clarifai_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
             elif self.model == "replicate" or self.custom_llm_provider == "replicate":
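
A last nit that formatting alone cannot catch: in the collapsed condition, the `self.custom_llm_provider and` guard is redundant, because `None == "clarifai"` is already `False`; `elif self.custom_llm_provider == "clarifai":` behaves identically and fits on one line with room to spare. A quick check:

    for provider in (None, "", "clarifai", "openai"):
        assert bool(provider and provider == "clarifai") == (provider == "clarifai")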