From 91a6a0eef40c4f1a9f91a8e78a989f4091fc10c7 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Sat, 11 May 2024 15:57:06 -0700
Subject: [PATCH] (Fix) - linting errors

---
 litellm/exceptions.py                       |  6 +-
 litellm/proxy/hooks/azure_content_safety.py |  3 +-
 litellm/utils.py                            | 71 +++++++++++++++------
 3 files changed, 54 insertions(+), 26 deletions(-)

diff --git a/litellm/exceptions.py b/litellm/exceptions.py
index 7c3471acf..d239f1e12 100644
--- a/litellm/exceptions.py
+++ b/litellm/exceptions.py
@@ -202,13 +202,11 @@ class BudgetExceededError(Exception):
 
 ## DEPRECATED ##
 class InvalidRequestError(openai.BadRequestError):  # type: ignore
-    def __init__(
-        self, message, model, llm_provider, response: Optional[httpx.Response] = None
-    ):
+    def __init__(self, message, model, llm_provider):
         self.status_code = 400
         self.message = message
         self.model = model
         self.llm_provider = llm_provider
         super().__init__(
-            self.message, response=response, body=None
+            self.message, f"{self.model}"
         )  # Call the base class constructor with the parameters it needs
diff --git a/litellm/proxy/hooks/azure_content_safety.py b/litellm/proxy/hooks/azure_content_safety.py
index 433571c15..5b5139f8c 100644
--- a/litellm/proxy/hooks/azure_content_safety.py
+++ b/litellm/proxy/hooks/azure_content_safety.py
@@ -4,6 +4,7 @@ from litellm.proxy._types import UserAPIKeyAuth
 import litellm, traceback, sys, uuid
 from fastapi import HTTPException
 from litellm._logging import verbose_proxy_logger
+from typing import Optional
 
 
 class _PROXY_AzureContentSafety(
@@ -71,7 +72,7 @@ class _PROXY_AzureContentSafety(
 
         return result
 
-    async def test_violation(self, content: str, source: str = None):
+    async def test_violation(self, content: str, source: Optional[str] = None):
         verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)
 
         # Construct a request
diff --git a/litellm/utils.py b/litellm/utils.py
index e787d2155..2eee41d8d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -35,6 +35,7 @@ from dataclasses import (
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 from litellm.caching import DualCache
+
 oidc_cache = DualCache()
 
 try:
@@ -2957,7 +2958,7 @@ def client(original_function):
                 )
             else:
                 return result
-
+
         return result
 
     # Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print
@@ -9559,16 +9560,20 @@ def get_secret(
                 if oidc_token is not None:
                     return oidc_token
 
-                client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
+                oidc_client = HTTPHandler(
+                    timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+                )
                 # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
-                response = client.get(
+                response = oidc_client.get(
                     "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
                     params={"audience": oidc_aud},
                     headers={"Metadata-Flavor": "Google"},
                 )
                 if response.status_code == 200:
                     oidc_token = response.text
-                    oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
+                    oidc_cache.set_cache(
+                        key=secret_name, value=oidc_token, ttl=3600 - 60
+                    )
                     return oidc_token
                 else:
                     raise ValueError("Google OIDC provider failed")
@@ -9587,25 +9592,34 @@ def get_secret(
             case "github":
                 # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
                 actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
-                actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
-                if actions_id_token_request_url is None or actions_id_token_request_token is None:
-                    raise ValueError("ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment")
+                actions_id_token_request_token = os.getenv(
+                    "ACTIONS_ID_TOKEN_REQUEST_TOKEN"
+                )
+                if (
+                    actions_id_token_request_url is None
+                    or actions_id_token_request_token is None
+                ):
+                    raise ValueError(
+                        "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
+                    )
 
                 oidc_token = oidc_cache.get_cache(key=secret_name)
                 if oidc_token is not None:
                     return oidc_token
 
-                client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
-                response = client.get(
+                oidc_client = HTTPHandler(
+                    timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+                )
+                response = oidc_client.get(
                     actions_id_token_request_url,
                     params={"audience": oidc_aud},
                     headers={
                         "Authorization": f"Bearer {actions_id_token_request_token}",
                         "Accept": "application/json; api-version=2.0",
-                    },
+                    },
                 )
                 if response.status_code == 200:
-                    oidc_token = response.text['value']
+                    oidc_token = response.json()["value"]
                     oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
                     return oidc_token
                 else:
@@ -9613,7 +9627,6 @@ def get_secret(
         case _:
             raise ValueError("Unsupported OIDC provider")
 
-
    try:
        if litellm.secret_manager_client is not None:
            try:
@@ -10562,7 +10575,12 @@ class CustomStreamWrapper:
                         response = chunk.replace("data: ", "").strip()
                         parsed_response = json.loads(response)
                     else:
-                        return {"text": "", "is_finished": False, "prompt_tokens": 0, "completion_tokens": 0}
+                        return {
+                            "text": "",
+                            "is_finished": False,
+                            "prompt_tokens": 0,
+                            "completion_tokens": 0,
+                        }
                 else:
                     print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
                     raise ValueError(
@@ -10583,19 +10601,32 @@ class CustomStreamWrapper:
                 return {"text": "", "is_finished": False}
         except Exception as e:
             raise e
-
+
     def handle_clarifai_completion_chunk(self, chunk):
         try:
             if isinstance(chunk, dict):
-                parsed_response = chunk
+                parsed_response = chunk
             if isinstance(chunk, (str, bytes)):
                 if isinstance(chunk, bytes):
                     parsed_response = chunk.decode("utf-8")
                 else:
                     parsed_response = chunk
-            data_json = json.loads(parsed_response)
-            text = data_json.get("outputs", "")[0].get("data", "").get("text", "").get("raw","")
-            prompt_tokens = len(encoding.encode(data_json.get("outputs", "")[0].get("input","").get("data", "").get("text", "").get("raw","")))
+            data_json = json.loads(parsed_response)
+            text = (
+                data_json.get("outputs", "")[0]
+                .get("data", "")
+                .get("text", "")
+                .get("raw", "")
+            )
+            prompt_tokens = len(
+                encoding.encode(
+                    data_json.get("outputs", "")[0]
+                    .get("input", "")
+                    .get("data", "")
+                    .get("text", "")
+                    .get("raw", "")
+                )
+            )
             completion_tokens = len(encoding.encode(text))
             return {
                 "text": text,
@@ -10650,9 +10681,7 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
-            elif (
-                self.custom_llm_provider and self.custom_llm_provider == "clarifai"
-            ):
+            elif self.custom_llm_provider and self.custom_llm_provider == "clarifai":
                 response_obj = self.handle_clarifai_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
             elif self.model == "replicate" or self.custom_llm_provider == "replicate":
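
For context, below is a minimal standalone sketch of the GitHub Actions OIDC flow that the get_secret() hunks above reformat: read ACTIONS_ID_TOKEN_REQUEST_URL and ACTIONS_ID_TOKEN_REQUEST_TOKEN from the environment, request a token for the desired audience, parse the JSON response body (GitHub returns a document like {"value": "<jwt>"}), and cache it for just under the short token lifetime (the ttl=300 - 5 seen above). This is a sketch under stated assumptions, not litellm's implementation: plain httpx and a module-level dict stand in for litellm's HTTPHandler and DualCache, and fetch_github_oidc_token / _token_cache are illustrative names.

# Sketch only: plain httpx + a dict cache stand in for litellm's
# HTTPHandler and DualCache; the names here are illustrative, not litellm APIs.
import os
import time

import httpx

_token_cache: dict = {}  # audience -> (expiry_epoch, token)


def fetch_github_oidc_token(audience: str) -> str:
    request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
    request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
    if request_url is None or request_token is None:
        raise ValueError(
            "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN "
            "not found in environment"
        )

    # Serve from cache while the token is still fresh.
    cached = _token_cache.get(audience)
    if cached is not None and cached[0] > time.time():
        return cached[1]

    response = httpx.get(
        request_url,
        params={"audience": audience},
        headers={
            "Authorization": f"Bearer {request_token}",
            "Accept": "application/json; api-version=2.0",
        },
        timeout=httpx.Timeout(timeout=600.0, connect=5.0),
    )
    if response.status_code != 200:
        raise ValueError("Github OIDC provider failed")

    # The endpoint returns a JSON body such as {"value": "<jwt>"}, so the
    # token must come from the parsed body, not from indexing response.text.
    token = response.json()["value"]
    _token_cache[audience] = (time.time() + 300 - 5, token)  # mirrors ttl=300 - 5
    return token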