diff --git a/litellm/proxy/config.yaml b/litellm/proxy/config.yaml index 4bd0f42f1..7aecf772e 100644 --- a/litellm/proxy/config.yaml +++ b/litellm/proxy/config.yaml @@ -2,27 +2,16 @@ model_list: - model_name: zephyr-alpha litellm_params: # params for litellm.completion() - https://docs.litellm.ai/docs/completion/input#input---request-body model: huggingface/HuggingFaceH4/zephyr-7b-alpha - max_tokens: 20 - temperature: 0 - - model_name: gpt-4-team1 + api_base: http://0.0.0.0:8001 + - model_name: zephyr-beta litellm_params: - model: azure/chatgpt-v-2 - api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ - api_version: "2023-05-15" - azure_ad_token: eyJ0eXAiOiJ - - model_name: gpt-4-team2 - litellm_params: - model: azure/gpt-4 - api_key: sk-123 - api_base: https://openai-gpt-4-test-v-2.openai.azure.com/ - - model_name: gpt-4-team3 - litellm_params: - model: azure/gpt-4 - api_key: sk-123 - - model_name: ollama/zephyr - litellm_params: - model: ollama/zephyr + model: huggingface/HuggingFaceH4/zephyr-7b-beta + api_base: https:// litellm_settings: drop_params: True - success_callback: ["langfuse"] # https://docs.litellm.ai/docs/observability/langfuse_integration \ No newline at end of file + set_verbose: True + +general_settings: + master_key: sk-1234 # [OPTIONAL] Only use this if you want to require all calls to contain this key (Authorization: Bearer sk-1234) + database_url: "postgresql://:@:/" # [OPTIONAL] use for token-based auth to proxy \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index be5570745..e51df68da 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1,7 +1,9 @@ -import sys, os, platform, time, copy +import sys, os, platform, time, copy, re import threading, ast import shutil, random, traceback, requests +from datetime import datetime, timedelta from typing import Optional +import secrets, subprocess messages: list = [] sys.path.insert( 0, os.path.abspath("../..") @@ 
-16,7 +18,6 @@ try: import backoff import yaml except ImportError: - import subprocess import sys subprocess.check_call( @@ -214,6 +215,14 @@ def save_params_to_config(data: dict): with open(user_config_path, "wb") as f: tomli_w.dump(config, f) +def prisma_setup(database_url: Optional[str]): + if database_url: + subprocess.run(['pip', 'install', 'prisma']) + subprocess.run(['python3', '-m', 'pip', 'install', 'prisma']) + subprocess.run(['prisma', 'db', 'push']) + # Now you can import the Prisma Client + from prisma import Client + def load_router_config(router: Optional[litellm.Router], config_file_path: str): global master_key @@ -230,10 +239,22 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): print(f"Loaded config YAML:\n{json.dumps(config, indent=2)}") + ## ENVIRONMENT VARIABLES + environment_variables = config.get('environment_variables', None) + if environment_variables: + for key, value in environment_variables.items(): + os.environ[key] = value + + ## GENERAL SERVER SETTINGS (e.g. master key,..) general_settings = config.get("general_settings", None) if general_settings: + ### MASTER KEY ### master_key = general_settings.get("master_key", None) + ### CONNECT TO DATABASE ### + database_url = general_settings.get("database_url", None) + prisma_setup(database_url=database_url) + ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) 
litellm_settings = config.get('litellm_settings', None) @@ -250,12 +271,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): print(f"\033[32m {model.get('model_name', '')}\033[0m") print() - ## ENVIRONMENT VARIABLES - environment_variables = config.get('environment_variables', None) - if environment_variables: - for key, value in environment_variables.items(): - os.environ[key] = value - return router, model_list, server_settings def load_config(): @@ -585,29 +600,54 @@ async def chat_completion(request: Request, model: Optional[str] = None): detail=error_msg ) +@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)]) +async def generate_key(request: Request): + data = await request.json() -@router.post("/router/chat/completions", dependencies=[Depends(user_api_key_auth)]) -async def router_completion(request: Request): - try: - body = await request.body() - body_str = body.decode() - try: - data = ast.literal_eval(body_str) - except: - data = json.loads(body_str) - return {"data": data} - except Exception as e: - print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. 
`litellm --model gpt-3.5-turbo --debug`") - error_traceback = traceback.format_exc() - error_msg = f"{str(e)}\n\n{error_traceback}" - try: - status = e.status_code # type: ignore - except: - status = status.HTTP_500_INTERNAL_SERVER_ERROR, - raise HTTPException( - status_code=status, - detail=error_msg + token = f"sk-{secrets.token_urlsafe(16)}" + duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided + + def _duration_in_seconds(duration: str): + match = re.match(r"(\d+)([smhd]?)", duration) + if not match: + raise ValueError("Invalid duration format") + + value, unit = match.groups() + value = int(value) + + if unit == "s": + return value + elif unit == "m": + return value * 60 + elif unit == "h": + return value * 3600 + elif unit == "d": + return value * 86400 + else: + raise ValueError("Unsupported duration unit") + + duration = _duration_in_seconds(duration=duration_str) + expires = datetime.utcnow() + timedelta(seconds=duration) + try: + from prisma import Client + db = Client() + await db.connect() + # Create a new verification token (you may want to enhance this logic based on your needs) + print(dir(db)) + verification_token_data = { + "token": token, + "expires": expires + } + new_verification_token = await db.litellm_verificationtoken.create( + {**verification_token_data} ) + print(f"new_verification_token: {new_verification_token}") + except Exception as e: + traceback.print_exc() + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + + return {"token": new_verification_token.token, "expires": new_verification_token.expires} + @router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)]) async def retrieve_server_log(request: Request): diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma new file mode 100644 index 000000000..d4b603db1 --- /dev/null +++ b/litellm/proxy/schema.prisma @@ -0,0 +1,14 @@ +datasource client { + provider = "postgresql" + url = env("DATABASE_URL") 
+} + +generator client { + provider = "prisma-client-py" +} + +// required for token gen +model LiteLLM_VerificationToken { + token String @unique + expires DateTime +} \ No newline at end of file