From f3016250147304cb5a52fe1e4e402a2140d7540b Mon Sep 17 00:00:00 2001
From: Steven Osborn
Date: Tue, 25 Jun 2024 09:03:05 -0700
Subject: [PATCH 1/5] create litellm user to fix issue in k8s where prisma
 fails because user nobody has no home directory
---
 Dockerfile.database | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/Dockerfile.database b/Dockerfile.database
index 22084bab8..1901200d5 100644
--- a/Dockerfile.database
+++ b/Dockerfile.database
@@ -9,6 +9,27 @@ FROM $LITELLM_BUILD_IMAGE as builder
 # Set the working directory to /app
 WORKDIR /app
 
+ARG LITELLM_USER=litellm LITELLM_UID=1729
+ARG LITELLM_GROUP=litellm LITELLM_GID=1729
+
+RUN groupadd \
+    --gid ${LITELLM_GID} \
+    ${LITELLM_GROUP} \
+    && useradd \
+    --create-home \
+    --shell /bin/sh \
+    --gid ${LITELLM_GID} \
+    --uid ${LITELLM_UID} \
+    ${LITELLM_USER}
+
+# Allow the litellm user to update the python install.
+# This is necessary for prisma.
+RUN chown -R ${LITELLM_USER}:${LITELLM_GROUP} /usr/local/lib/python3.11
+
+# Set the HOME var forcefully because of prisma.
+ENV HOME=/home/${LITELLM_USER}
+USER ${LITELLM_USER}
+
 # Install build dependencies
 RUN apt-get clean && apt-get update && \
     apt-get install -y gcc python3-dev && \

From 6889a4c0dd2275fc112dbe0badda42bd68f1adf0 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 25 Jun 2024 13:47:38 -0700
Subject: [PATCH 2/5] fix(utils.py): predibase exception mapping - map 424 as
 a badrequest error
---
 litellm/llms/predibase.py               | 39 +++++++++++++------------
 litellm/proxy/_super_secret_config.yaml |  5 +++-
 litellm/utils.py                        | 12 +++-----
 3 files changed, 28 insertions(+), 28 deletions(-)

diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py
index 8ad294457..7a137da70 100644
--- a/litellm/llms/predibase.py
+++ b/litellm/llms/predibase.py
@@ -1,27 +1,26 @@
 # What is this?
 ## Controller file for Predibase Integration - https://predibase.com/
 
-from functools import partial
-import os, types
-import traceback
+import copy
 import json
-from enum import Enum
-import requests, copy  # type: ignore
+import os
 import time
-from typing import Callable, Optional, List, Literal, Union
-from litellm.utils import (
-    ModelResponse,
-    Usage,
-    CustomStreamWrapper,
-    Message,
-    Choices,
-)
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
-import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
-from .base import BaseLLM
+import traceback
+import types
+from enum import Enum
+from functools import partial
+from typing import Callable, List, Literal, Optional, Union
+
 import httpx  # type: ignore
+import requests  # type: ignore
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
+
+from .base import BaseLLM
+from .prompt_templates.factory import custom_prompt, prompt_factory
 
 
 class PredibaseError(Exception):
@@ -496,7 +495,9 @@ class PredibaseChatCompletion(BaseLLM):
         except httpx.HTTPStatusError as e:
             raise PredibaseError(
                 status_code=e.response.status_code,
-                message="HTTPStatusError - {}".format(e.response.text),
+                message="HTTPStatusError - received status_code={}, error_message={}".format(
+                    e.response.status_code, e.response.text
+                ),
             )
         except Exception as e:
             raise PredibaseError(
diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index c5f1b4768..94df97c54 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -14,9 +14,10 @@ model_list:
 - model_name: fake-openai-endpoint
   litellm_params:
     model: predibase/llama-3-8b-instruct
-    api_base: "http://0.0.0.0:8000"
+    # api_base: "http://0.0.0.0:8081"
    api_key: os.environ/PREDIBASE_API_KEY
     tenant_id: os.environ/PREDIBASE_TENANT_ID
+    adapter_id: qwoiqjdoqin
     max_retries: 0
     temperature: 0.1
     max_new_tokens: 256
@@ -73,6 +74,8 @@ model_list:
 
 litellm_settings:
   callbacks: ["dynamic_rate_limiter"]
+  # success_callback: ["langfuse"]
+  # failure_callback: ["langfuse"]
 #   default_team_settings:
 #     - team_id: proj1
 #       success_callback: ["langfuse"]
diff --git a/litellm/utils.py b/litellm/utils.py
index 9f6ebaff0..00833003b 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6157,13 +6157,6 @@ def exception_type(
                     response=original_exception.response,
                     litellm_debug_info=extra_information,
                 )
-            if "Request failed during generation" in error_str:
-                # this is an internal server error from predibase
-                raise litellm.InternalServerError(
-                    message=f"PredibaseException - {error_str}",
-                    llm_provider="predibase",
-                    model=model,
-                )
             elif hasattr(original_exception, "status_code"):
                 if original_exception.status_code == 500:
                     exception_mapping_worked = True
@@ -6201,7 +6194,10 @@ def exception_type(
                         llm_provider=custom_llm_provider,
                         litellm_debug_info=extra_information,
                     )
-                elif original_exception.status_code == 422:
+                elif (
+                    original_exception.status_code == 422
+                    or original_exception.status_code == 424
+                ):
                     exception_mapping_worked = True
                     raise BadRequestError(
                         message=f"PredibaseException - {original_exception.message}",

From 4c99010eeea1d21e1370326fd1dcd16133b4b99a Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Tue, 25 Jun 2024 13:55:54 -0700
Subject: [PATCH 3/5] fix - verify license without api request
---
 litellm/proxy/auth/litellm_license.py | 65 +++++++++++++++++++++++++++
 litellm/proxy/auth/public_key.pem     |  9 ++++
 requirements.txt                      |  1 +
 3 files changed, 75 insertions(+)
 create mode 100644 litellm/proxy/auth/public_key.pem

diff --git a/litellm/proxy/auth/litellm_license.py b/litellm/proxy/auth/litellm_license.py
index ffd9f5273..ec51f904c 100644
--- a/litellm/proxy/auth/litellm_license.py
+++ b/litellm/proxy/auth/litellm_license.py
@@ -1,6 +1,14 @@
 # What is this?
 ## If litellm license in env, checks if it's valid
+import base64
+import json
 import os
+from datetime import datetime
+
+from cryptography.hazmat.primitives import hashes, serialization
+from cryptography.hazmat.primitives.asymmetric import padding, rsa
+
+from litellm._logging import verbose_proxy_logger
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 
 
@@ -15,6 +23,20 @@ class LicenseCheck:
     def __init__(self) -> None:
         self.license_str = os.getenv("LITELLM_LICENSE", None)
         self.http_handler = HTTPHandler()
+        self.public_key = None
+        self.read_public_key()
+
+    def read_public_key(self):
+        # current dir
+        current_dir = os.path.dirname(os.path.realpath(__file__))
+
+        # check if public_key.pem exists
+        _path_to_public_key = os.path.join(current_dir, "public_key.pem")
+        if os.path.exists(_path_to_public_key):
+            with open(_path_to_public_key, "rb") as key_file:
+                self.public_key = serialization.load_pem_public_key(key_file.read())
+        else:
+            self.public_key = None
 
     def _verify(self, license_str: str) -> bool:
         url = "{}/verify_license/{}".format(self.base_url, license_str)
@@ -35,11 +57,54 @@ class LicenseCheck:
             return False
 
     def is_premium(self) -> bool:
+        """
+        1. verify_license_without_api_request: checks if the license was generated using our private / public key pair
+        2. _verify: checks if the license is valid by calling the LiteLLM API. This is the legacy way licenses were generated/validated.
+        """
         try:
             if self.license_str is None:
                 return False
+            elif self.verify_license_without_api_request(
+                public_key=self.public_key, license_key=self.license_str
+            ):
+                return True
             elif self._verify(license_str=self.license_str):
                 return True
             return False
         except Exception as e:
             return False
+
+    def verify_license_without_api_request(self, public_key, license_key):
+        try:
+            # Decode the license key
+            decoded = base64.b64decode(license_key)
+            message, signature = decoded.split(b".", 1)
+
+            # Verify the signature
+            public_key.verify(
+                signature,
+                message,
+                padding.PSS(
+                    mgf=padding.MGF1(hashes.SHA256()),
+                    salt_length=padding.PSS.MAX_LENGTH,
+                ),
+                hashes.SHA256(),
+            )
+
+            # Decode and parse the data
+            license_data = json.loads(message.decode())
+
+            # debug information provided in license data
+            verbose_proxy_logger.debug("License data: %s", license_data)
+
+            # Check expiration date
+            expiration_date = datetime.strptime(
+                license_data["expiration_date"], "%Y-%m-%d"
+            )
+            if expiration_date < datetime.now():
+                return False
+
+            return True
+
+        except Exception as e:
+            return False
diff --git a/litellm/proxy/auth/public_key.pem b/litellm/proxy/auth/public_key.pem
new file mode 100644
index 000000000..12a69dde2
--- /dev/null
+++ b/litellm/proxy/auth/public_key.pem
@@ -0,0 +1,9 @@
+-----BEGIN PUBLIC KEY-----
+MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAmfBuNiNzDkNWyce23koQ
+w0vq3bSVHkq7fd9Sw/U1q7FwRwL221daLTyGWssd8xAoQSFXAJKoBwzJQ9wd+o44
+lfL54E3a61nfjZuF+D9ntpXZFfEAxLVtIahDeQjUz4b/EpgciWIJyUfjCJrQo6LY
+eyAZPTGSO8V3zHyaU+CFywq5XCuCnfZqCZeCw051St59A2v8W32mXSCJ+A+x0hYP
+yXJyRRFcefSFG5IBuRHr4Y24Vx7NUIAoco5cnxJho9g2z3J/Hb0GKW+oBNvRVumk
+nuA2Ljmjh4yI0OoTIW8ZWxemvCCJHSjdfKlMyb+QI4fmeiIUZzP5Au+F561Styqq
+YQIDAQAB
+-----END PUBLIC KEY-----
diff --git a/requirements.txt b/requirements.txt
index fbf2bfc1d..8c5e4ab3b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -31,6 +31,7 @@ azure-identity==1.16.1 # for azure content safety
 opentelemetry-api==1.25.0
 opentelemetry-sdk==1.25.0
 opentelemetry-exporter-otlp==1.25.0
+cryptography==42.0.7
 
 ### LITELLM PACKAGE DEPENDENCIES
 python-dotenv==1.0.0 # for env

From e813e984f74ea09ea92646c44c5a5ab7a30bbff0 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 25 Jun 2024 16:03:47 -0700
Subject: [PATCH 4/5] fix(predibase.py): support json schema on predibase
---
 litellm/llms/predibase.py               | 59 ++++++++++++++++++++++---
 litellm/proxy/_super_secret_config.yaml | 16 +++----
 litellm/utils.py                        | 10 ++++-
 3 files changed, 67 insertions(+), 18 deletions(-)

diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py
index 7a137da70..534f8e26f 100644
--- a/litellm/llms/predibase.py
+++ b/litellm/llms/predibase.py
@@ -15,6 +15,8 @@ import httpx  # type: ignore
 import requests  # type: ignore
 
 import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
@@ -145,7 +147,49 @@ class PredibaseConfig:
         }
 
     def get_supported_openai_params(self):
-        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "response_format",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+                optional_params["decoder_input_details"] = True
+            if param == "response_format":
+                optional_params["response_format"] = value
+        return optional_params
 
 
 class PredibaseChatCompletion(BaseLLM):
@@ -224,15 +268,16 @@ class PredibaseChatCompletion(BaseLLM):
                 status_code=response.status_code,
             )
         else:
-            if (
-                not isinstance(completion_response, dict)
-                or "generated_text" not in completion_response
-            ):
+            if not isinstance(completion_response, dict):
                 raise PredibaseError(
                     status_code=422,
-                    message=f"response is not in expected format - {completion_response}",
+                    message=f"'completion_response' is not a dictionary - {completion_response}",
+                )
+            elif "generated_text" not in completion_response:
+                raise PredibaseError(
+                    status_code=422,
+                    message=f"'generated_text' is not a key in the response dictionary - {completion_response}",
                 )
-
             if len(completion_response["generated_text"]) > 0:
                 model_response["choices"][0]["message"]["content"] = self.output_parser(
                     completion_response["generated_text"]
diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index 94df97c54..2060f61ca 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -14,14 +14,10 @@ model_list:
 - model_name: fake-openai-endpoint
   litellm_params:
     model: predibase/llama-3-8b-instruct
-    # api_base: "http://0.0.0.0:8081"
+    api_base: "http://0.0.0.0:8081"
     api_key: os.environ/PREDIBASE_API_KEY
     tenant_id: os.environ/PREDIBASE_TENANT_ID
-    adapter_id: qwoiqjdoqin
-    max_retries: 0
-    temperature: 0.1
     max_new_tokens: 256
-    return_full_text: false
 
 # - litellm_params:
 #     api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
@@ -97,8 +93,8 @@ assistant_settings:
 
 router_settings:
   enable_pre_call_checks: true
-general_settings:
-  alerting: ["slack"]
-  enable_jwt_auth: True
-  litellm_jwtauth:
-    team_id_jwt_field: "client_id"
\ No newline at end of file
+# general_settings:
+#   # alerting: ["slack"]
+#   enable_jwt_auth: True
+#   litellm_jwtauth:
+#     team_id_jwt_field: "client_id"
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 00833003b..4465c5b0a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2609,7 +2609,15 @@ def get_optional_params(
             optional_params["top_p"] = top_p
         if stop is not None:
             optional_params["stop_sequences"] = stop
-    elif custom_llm_provider == "huggingface" or custom_llm_provider == "predibase":
+    elif custom_llm_provider == "predibase":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.PredibaseConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
+    elif custom_llm_provider == "huggingface":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider

From 4abb83b12dff6b9cada26d9ce86fa29ce3860c9c Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Tue, 25 Jun 2024 16:28:47 -0700
Subject: [PATCH 5/5] fix only use crypto imports when needed
---
 litellm/proxy/auth/litellm_license.py | 31 ++++++++++++++++-----------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/litellm/proxy/auth/litellm_license.py b/litellm/proxy/auth/litellm_license.py
index ec51f904c..0310dcaf5 100644
--- a/litellm/proxy/auth/litellm_license.py
+++ b/litellm/proxy/auth/litellm_license.py
@@ -5,9 +5,6 @@ import json
 import os
 from datetime import datetime
 
-from cryptography.hazmat.primitives import hashes, serialization
-from cryptography.hazmat.primitives.asymmetric import padding, rsa
-
 from litellm._logging import verbose_proxy_logger
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 
@@ -27,16 +24,22 @@ class LicenseCheck:
         self.read_public_key()
 
     def read_public_key(self):
-        # current dir
-        current_dir = os.path.dirname(os.path.realpath(__file__))
+        try:
+            from cryptography.hazmat.primitives import hashes, serialization
+            from cryptography.hazmat.primitives.asymmetric import padding, rsa
 
-        # check if public_key.pem exists
-        _path_to_public_key = os.path.join(current_dir, "public_key.pem")
-        if os.path.exists(_path_to_public_key):
-            with open(_path_to_public_key, "rb") as key_file:
-                self.public_key = serialization.load_pem_public_key(key_file.read())
-        else:
-            self.public_key = None
+            # current dir
+            current_dir = os.path.dirname(os.path.realpath(__file__))
+
+            # check if public_key.pem exists
+            _path_to_public_key = os.path.join(current_dir, "public_key.pem")
+            if os.path.exists(_path_to_public_key):
+                with open(_path_to_public_key, "rb") as key_file:
+                    self.public_key = serialization.load_pem_public_key(key_file.read())
+            else:
+                self.public_key = None
+        except Exception as e:
+            verbose_proxy_logger.error(f"Error reading public key: {str(e)}")
 
     def _verify(self, license_str: str) -> bool:
         url = "{}/verify_license/{}".format(self.base_url, license_str)
@@ -76,6 +79,9 @@ class LicenseCheck:
 
     def verify_license_without_api_request(self, public_key, license_key):
         try:
+            from cryptography.hazmat.primitives import hashes, serialization
+            from cryptography.hazmat.primitives.asymmetric import padding, rsa
+
             # Decode the license key
             decoded = base64.b64decode(license_key)
             message, signature = decoded.split(b".", 1)
@@ -107,4 +113,5 @@
             return True
 
         except Exception as e:
+            verbose_proxy_logger.error(str(e))
             return False
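
Reviewer note: patches 3/5 and 5/5 together imply a license-key format of base64(json_payload + b"." + RSA-PSS/SHA-256 signature), verified against the bundled public_key.pem. Below is a minimal, self-contained sketch of a compatible sign/verify round trip using the same `cryptography` primitives the patch imports. The throwaway key pair, the 2048-bit key size, and the single `expiration_date` payload field are illustrative assumptions, not the actual LiteLLM license format or signing service.

import base64
import json
from datetime import datetime

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa

# Throwaway key pair for the demo; the real verifier loads public_key.pem.
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
public_key = private_key.public_key()

# Hypothetical payload - only expiration_date is read by the patch.
payload = json.dumps({"expiration_date": "2030-01-01"}).encode()

# Same padding parameters verify_license_without_api_request passes to verify().
pss = padding.PSS(
    mgf=padding.MGF1(hashes.SHA256()),
    salt_length=padding.PSS.MAX_LENGTH,
)

# Sign with RSA-PSS/SHA-256 and assemble the key: base64(payload + b"." + signature).
signature = private_key.sign(payload, pss, hashes.SHA256())
license_key = base64.b64encode(payload + b"." + signature).decode()

# Offline verification, mirroring the patch: split on the FIRST dot,
# check the signature, then check expiry.
message, sig = base64.b64decode(license_key).split(b".", 1)
public_key.verify(sig, message, pss, hashes.SHA256())  # raises InvalidSignature on tampering

license_data = json.loads(message.decode())
expires = datetime.strptime(license_data["expiration_date"], "%Y-%m-%d")
print("valid:", expires > datetime.now())

One caveat this round trip surfaces: `decoded.split(b".", 1)` relies on the JSON payload never containing a `.` byte, which holds for %Y-%m-%d dates and plain identifiers but would break for payloads with dotted strings (e.g. versions or emails).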