From 04e5963b653e705823992db01654649d74291c69 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Sat, 5 Oct 2024 21:26:51 -0400 Subject: [PATCH] Litellm expose disable schema update flag (#6085) * fix: enable new 'disable_prisma_schema_update' flag * build(config.yml): remove setup remote docker step * ci(config.yml): give container time to start up * ci(config.yml): update test * build(config.yml): actually start docker * build(config.yml): simplify grep check * fix(prisma_client.py): support reading disable_schema_update via env vars * ci(config.yml): add test to check if all general settings are documented * build(test_General_settings.py): check available dir * ci: check ../ repo path * build: check ./ * build: fix test --- .circleci/config.yml | 54 +++- docs/my-website/docs/proxy/configs.md | 39 ++- litellm/proxy/_new_secret_config.yaml | 2 +- litellm/proxy/db/check_migration.py | 102 +++++++ litellm/proxy/db/prisma_client.py | 17 ++ .../example_config_yaml/bad_schema.prisma | 265 ++++++++++++++++++ .../disable_schema_update.yaml | 12 + litellm/proxy/prisma_migration.py | 5 + litellm/proxy/proxy_cli.py | 48 ++-- litellm/proxy/utils.py | 22 +- .../test_general_setting_keys.py | 76 +++++ 11 files changed, 598 insertions(+), 44 deletions(-) create mode 100644 litellm/proxy/db/check_migration.py create mode 100644 litellm/proxy/example_config_yaml/bad_schema.prisma create mode 100644 litellm/proxy/example_config_yaml/disable_schema_update.yaml create mode 100644 tests/documentation_tests/test_general_setting_keys.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 748bf14f7..1e0918c3a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -299,7 +299,7 @@ jobs: ls python -m pytest -vv tests/local_testing/test_python_38.py - check_code_quality: + check_code_and_doc_quality: docker: - image: cimg/python:3.11 auth: @@ -319,7 +319,46 @@ jobs: pip install . - run: python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1) - run: ruff check ./litellm - + - run: python ./tests/documentation_tests/test_general_setting_keys.py + + db_migration_disable_update_check: + machine: + image: ubuntu-2204:2023.10.1 + resource_class: xlarge + working_directory: ~/project + steps: + - checkout + - run: + name: Build Docker image + command: | + docker build -t myapp . + - run: + name: Run Docker container + command: | + docker run --name my-app \ + -p 4000:4000 \ + -e DATABASE_URL=$PROXY_DATABASE_URL \ + -e DISABLE_SCHEMA_UPDATE="True" \ + -v $(pwd)/litellm/proxy/example_config_yaml/bad_schema.prisma:/app/schema.prisma \ + -v $(pwd)/litellm/proxy/example_config_yaml/bad_schema.prisma:/app/litellm/proxy/schema.prisma \ + -v $(pwd)/litellm/proxy/example_config_yaml/disable_schema_update.yaml:/app/config.yaml \ + myapp:latest \ + --config /app/config.yaml \ + --port 4000 > docker_output.log 2>&1 || true + - run: + name: Display Docker logs + command: cat docker_output.log + - run: + name: Check for expected error + command: | + if grep -q "prisma schema out of sync with db. Consider running these sql_commands to sync the two" docker_output.log; then + echo "Expected error found. Test passed." + else + echo "Expected error not found. Test failed." + cat docker_output.log + exit 1 + fi + build_and_test: machine: image: ubuntu-2204:2023.10.1 @@ -827,7 +866,7 @@ workflows: only: - main - /litellm_.*/ - - check_code_quality: + - check_code_and_doc_quality: filters: branches: only: @@ -869,6 +908,12 @@ workflows: only: - main - /litellm_.*/ + - db_migration_disable_update_check: + filters: + branches: + only: + - main + - /litellm_.*/ - installing_litellm_on_python: filters: branches: @@ -890,11 +935,12 @@ workflows: - litellm_router_testing - litellm_assistants_api_testing - ui_endpoint_testing + - db_migration_disable_update_check - e2e_ui_testing - installing_litellm_on_python - proxy_logging_guardrails_model_info_tests - proxy_pass_through_endpoint_tests - - check_code_quality + - check_code_and_doc_quality filters: branches: only: diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index c89c6ac0c..e73db18ad 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -674,7 +674,44 @@ general_settings: | alerting | array of strings | List of alerting methods [Doc on Slack Alerting](alerting) | | alerting_threshold | integer | The threshold for triggering alerts [Doc on Slack Alerting](alerting) | | use_client_credentials_pass_through_routes | boolean | If true, uses client credentials for all pass-through routes. [Doc on pass through routes](pass_through) | - +| health_check_details | boolean | If false, hides health check details (e.g. remaining rate limit). [Doc on health checks](health) | +| public_routes | List[str] | (Enterprise Feature) Control list of public routes | +| alert_types | List[str] | Control list of alert types to send to slack (Doc on alert types)[./alerting.md] | +| enforced_params | List[str] | (Enterprise Feature) List of params that must be included in all requests to the proxy | +| enable_oauth2_auth | boolean | (Enterprise Feature) If true, enables oauth2.0 authentication | +| use_x_forwarded_for | str | If true, uses the X-Forwarded-For header to get the client IP address | +| service_account_settings | List[Dict[str, Any]] | Set `service_account_settings` if you want to create settings that only apply to service account keys (Doc on service accounts)[./service_accounts.md] | +| image_generation_model | str | The default model to use for image generation - ignores model set in request | +| store_model_in_db | boolean | If true, allows `/model/new` endpoint to store model information in db. Endpoint disabled by default. [Doc on `/model/new` endpoint](./model_management.md#create-a-new-model) | +| max_request_size_mb | int | The maximum size for requests in MB. Requests above this size will be rejected. | +| max_response_size_mb | int | The maximum size for responses in MB. LLM Responses above this size will not be sent. | +| proxy_budget_rescheduler_min_time | int | The minimum time (in seconds) to wait before checking db for budget resets. | +| proxy_budget_rescheduler_max_time | int | The maximum time (in seconds) to wait before checking db for budget resets. | +| proxy_batch_write_at | int | Time (in seconds) to wait before batch writing spend logs to the db. | +| alerting_args | dict | Args for Slack Alerting [Doc on Slack Alerting](./alerting.md) | +| custom_key_generate | str | Custom function for key generation [Doc on custom key generation](./virtual_keys.md#custom--key-generate) | +| allowed_ips | List[str] | List of IPs allowed to access the proxy. If not set, all IPs are allowed. | +| embedding_model | str | The default model to use for embeddings - ignores model set in request | +| default_team_disabled | boolean | If true, users cannot create 'personal' keys (keys with no team_id). | +| alert_to_webhook_url | Dict[str] | [Specify a webhook url for each alert type.](./alerting.md#set-specific-slack-channels-per-alert-type) | +| key_management_settings | List[Dict[str, Any]] | Settings for key management system (e.g. AWS KMS, Azure Key Vault) [Doc on key management](../secret.md) | +| allow_user_auth | boolean | (Deprecated) old approach for user authentication. | +| user_api_key_cache_ttl | int | The time (in seconds) to cache user api keys in memory. | +| disable_prisma_schema_update | boolean | If true, turns off automatic schema updates to DB | +| litellm_key_header_name | str | If set, allows passing LiteLLM keys as a custom header. [Doc on custom headers](./virtual_keys.md#custom-headers) | +| moderation_model | str | The default model to use for moderation. | +| custom_sso | str | Path to a python file that implements custom SSO logic. [Doc on custom SSO](./custom_sso.md) | +| allow_client_side_credentials | boolean | If true, allows passing client side credentials to the proxy. (Useful when testing finetuning models) [Doc on client side credentials](./virtual_keys.md#client-side-credentials) | +| admin_only_routes | List[str] | (Enterprise Feature) List of routes that are only accessible to admin users. [Doc on admin only routes](./enterprise#control-available-public-private-routes) | +| use_azure_key_vault | boolean | If true, load keys from azure key vault | +| use_google_kms | boolean | If true, load keys from google kms | +| spend_report_frequency | str | Specify how often you want a Spend Report to be sent (e.g. "1d", "2d", "30d") [More on this](./alerting.md#spend-report-frequency) | +| ui_access_mode | Literal["admin_only"] | If set, restricts access to the UI to admin users only. [Docs](./ui.md#restrict-ui-access) | +| litellm_jwtauth | Dict[str, Any] | Settings for JWT authentication. [Docs](./token_auth.md) | +| litellm_license | str | The license key for the proxy. [Docs](../enterprise.md#how-does-deployment-with-enterprise-license-work) | +| oauth2_config_mappings | Dict[str, str] | Define the OAuth2 config mappings | +| pass_through_endpoints | List[Dict[str, Any]] | Define the pass through endpoints. [Docs](./pass_through) | +| enable_oauth2_proxy_auth | boolean | (Enterprise Feature) If true, enables oauth2.0 authentication | ### router_settings - Reference ```yaml diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index e2afa3f12..9f9e02fe9 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -2,4 +2,4 @@ model_list: - model_name: claude-3-5-sonnet-20240620 litellm_params: model: anthropic/claude-3-5-sonnet-20240620 - api_key: os.environ/ANTHROPIC_API_KEY \ No newline at end of file + api_key: os.environ/ANTHROPIC_API_KEY diff --git a/litellm/proxy/db/check_migration.py b/litellm/proxy/db/check_migration.py new file mode 100644 index 000000000..4ce12325f --- /dev/null +++ b/litellm/proxy/db/check_migration.py @@ -0,0 +1,102 @@ +"""Module for checking differences between Prisma schema and database.""" + +import os +import subprocess +from typing import List, Optional, Tuple + + +def extract_sql_commands(diff_output: str) -> List[str]: + """ + Extract SQL commands from the Prisma migrate diff output. + Args: + diff_output (str): The full output from prisma migrate diff. + Returns: + List[str]: A list of SQL commands extracted from the diff output. + """ + # Split the output into lines and remove empty lines + lines = [line.strip() for line in diff_output.split("\n") if line.strip()] + + sql_commands = [] + current_command = "" + in_sql_block = False + + for line in lines: + if line.startswith("-- "): # Comment line, likely a table operation description + if in_sql_block and current_command: + sql_commands.append(current_command.strip()) + current_command = "" + in_sql_block = True + elif in_sql_block: + if line.endswith(";"): + current_command += line + sql_commands.append(current_command.strip()) + current_command = "" + in_sql_block = False + else: + current_command += line + " " + + # Add any remaining command + if current_command: + sql_commands.append(current_command.strip()) + + return sql_commands + + +def check_prisma_schema_diff_helper(db_url: str) -> Tuple[bool, List[str]]: + """Checks for differences between current database and Prisma schema. + Returns: + A tuple containing: + - A boolean indicating if differences were found (True) or not (False). + - A string with the diff output or error message. + Raises: + subprocess.CalledProcessError: If the Prisma command fails. + Exception: For any other errors during execution. + """ + try: + result = subprocess.run( + [ + "prisma", + "migrate", + "diff", + "--from-url", + db_url, + "--to-schema-datamodel", + "./schema.prisma", + "--script", + ], + capture_output=True, + text=True, + check=True, + ) + + # return True, "Migration diff generated successfully." + sql_commands = extract_sql_commands(result.stdout) + + if sql_commands: + print("Changes to DB Schema detected") # noqa: T201 + print("Required SQL commands:") # noqa: T201 + for command in sql_commands: + print(command) # noqa: T201 + return True, sql_commands + else: + print("No changes required.") # noqa: T201 + return False, [] + except subprocess.CalledProcessError as e: + error_message = f"Failed to generate migration diff. Error: {e.stderr}" + print(error_message) # noqa: T201 + return False, [] + + +def check_prisma_schema_diff(db_url: Optional[str] = None) -> None: + """Main function to run the Prisma schema diff check.""" + if db_url is None: + db_url = os.getenv("DATABASE_URL") + if db_url is None: + raise Exception("DATABASE_URL not set") + has_diff, message = check_prisma_schema_diff_helper(db_url) + if has_diff: + raise Exception( + "prisma schema out of sync with db. Consider running these sql_commands to sync the two - {}".format( + message + ) + ) diff --git a/litellm/proxy/db/prisma_client.py b/litellm/proxy/db/prisma_client.py index 5e7fc4f79..76e425bf2 100644 --- a/litellm/proxy/db/prisma_client.py +++ b/litellm/proxy/db/prisma_client.py @@ -1,3 +1,7 @@ +""" +This file contains the PrismaWrapper class, which is used to wrap the Prisma client and handle the RDS IAM token. +""" + import asyncio import os import urllib @@ -5,6 +9,8 @@ import urllib.parse from datetime import datetime, timedelta from typing import Any, Callable, Optional +from litellm.secret_managers.main import str_to_bool + class PrismaWrapper: def __init__(self, original_prisma: Any, iam_token_db_auth: bool): @@ -104,3 +110,14 @@ class PrismaWrapper: raise ValueError("Failed to get RDS IAM token") return original_attr + + +def should_update_schema(disable_prisma_schema_update: Optional[bool]): + """ + This function is used to determine if the Prisma schema should be updated. + """ + if disable_prisma_schema_update is None: + disable_prisma_schema_update = str_to_bool(os.getenv("DISABLE_SCHEMA_UPDATE")) + if disable_prisma_schema_update is True: + return False + return True diff --git a/litellm/proxy/example_config_yaml/bad_schema.prisma b/litellm/proxy/example_config_yaml/bad_schema.prisma new file mode 100644 index 000000000..5c631406a --- /dev/null +++ b/litellm/proxy/example_config_yaml/bad_schema.prisma @@ -0,0 +1,265 @@ +datasource client { + provider = "postgresql" + url = env("DATABASE_URL") +} + +generator client { + provider = "prisma-client-py" +} + +// Budget / Rate Limits for an org +model LiteLLM_BudgetTable { + budget_id String @id @default(uuid()) + max_budget Float? + soft_budget Float? + max_parallel_requests Int? + tpm_limit BigInt? + rpm_limit BigInt? + model_max_budget Json? + temp_verification_token String? // bad param for testing + budget_duration String? + budget_reset_at DateTime? + created_at DateTime @default(now()) @map("created_at") + created_by String + updated_at DateTime @default(now()) @updatedAt @map("updated_at") + updated_by String + organization LiteLLM_OrganizationTable[] // multiple orgs can have the same budget + keys LiteLLM_VerificationToken[] // multiple keys can have the same budget + end_users LiteLLM_EndUserTable[] // multiple end-users can have the same budget + team_membership LiteLLM_TeamMembership[] // budgets of Users within a Team +} + +// Models on proxy +model LiteLLM_ProxyModelTable { + model_id String @id @default(uuid()) + model_name String + litellm_params Json + model_info Json? + created_at DateTime @default(now()) @map("created_at") + created_by String + updated_at DateTime @default(now()) @updatedAt @map("updated_at") + updated_by String +} + +model LiteLLM_OrganizationTable { + organization_id String @id @default(uuid()) + organization_alias String + budget_id String + metadata Json @default("{}") + models String[] + spend Float @default(0.0) + model_spend Json @default("{}") + created_at DateTime @default(now()) @map("created_at") + created_by String + updated_at DateTime @default(now()) @updatedAt @map("updated_at") + updated_by String + litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) + teams LiteLLM_TeamTable[] + users LiteLLM_UserTable[] +} + +// Model info for teams, just has model aliases for now. +model LiteLLM_ModelTable { + id Int @id @default(autoincrement()) + model_aliases Json? @map("aliases") + created_at DateTime @default(now()) @map("created_at") + created_by String + updated_at DateTime @default(now()) @updatedAt @map("updated_at") + updated_by String + team LiteLLM_TeamTable? +} + + +// Assign prod keys to groups, not individuals +model LiteLLM_TeamTable { + team_id String @id @default(uuid()) + team_alias String? + organization_id String? + admins String[] + members String[] + members_with_roles Json @default("{}") + metadata Json @default("{}") + max_budget Float? + spend Float @default(0.0) + models String[] + max_parallel_requests Int? + tpm_limit BigInt? + rpm_limit BigInt? + budget_duration String? + budget_reset_at DateTime? + blocked Boolean @default(false) + created_at DateTime @default(now()) @map("created_at") + updated_at DateTime @default(now()) @updatedAt @map("updated_at") + model_spend Json @default("{}") + model_max_budget Json @default("{}") + model_id Int? @unique // id for LiteLLM_ModelTable -> stores team-level model aliases + litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) + litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id]) +} + +// Track spend, rate limit, budget Users +model LiteLLM_UserTable { + user_id String @id + user_alias String? + team_id String? + organization_id String? + password String? + teams String[] @default([]) + user_role String? + max_budget Float? + spend Float @default(0.0) + user_email String? + models String[] + metadata Json @default("{}") + max_parallel_requests Int? + tpm_limit BigInt? + rpm_limit BigInt? + budget_duration String? + budget_reset_at DateTime? + allowed_cache_controls String[] @default([]) + model_spend Json @default("{}") + model_max_budget Json @default("{}") + litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) + invitations_created LiteLLM_InvitationLink[] @relation("CreatedBy") + invitations_updated LiteLLM_InvitationLink[] @relation("UpdatedBy") + invitations_user LiteLLM_InvitationLink[] @relation("UserId") +} + +// Generate Tokens for Proxy +model LiteLLM_VerificationToken { + token String @id + key_name String? + key_alias String? + soft_budget_cooldown Boolean @default(false) // key-level state on if budget alerts need to be cooled down + spend Float @default(0.0) + expires DateTime? + models String[] + aliases Json @default("{}") + config Json @default("{}") + user_id String? + team_id String? + permissions Json @default("{}") + max_parallel_requests Int? + metadata Json @default("{}") + blocked Boolean? + tpm_limit BigInt? + rpm_limit BigInt? + max_budget Float? + budget_duration String? + budget_reset_at DateTime? + allowed_cache_controls String[] @default([]) + model_spend Json @default("{}") + model_max_budget Json @default("{}") + budget_id String? + litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) +} + +model LiteLLM_EndUserTable { + user_id String @id + alias String? // admin-facing alias + spend Float @default(0.0) + allowed_model_region String? // require all user requests to use models in this specific region + default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model. + budget_id String? + litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) + blocked Boolean @default(false) +} + +// store proxy config.yaml +model LiteLLM_Config { + param_name String @id + param_value Json? +} + +// View spend, model, api_key per request +model LiteLLM_SpendLogs { + request_id String @id + call_type String + api_key String @default ("") // Hashed API Token. Not the actual Virtual Key. Equivalent to 'token' column in LiteLLM_VerificationToken + spend Float @default(0.0) + total_tokens Int @default(0) + prompt_tokens Int @default(0) + completion_tokens Int @default(0) + startTime DateTime // Assuming start_time is a DateTime field + endTime DateTime // Assuming end_time is a DateTime field + completionStartTime DateTime? // Assuming completionStartTime is a DateTime field + model String @default("") + model_id String? @default("") // the model id stored in proxy model db + model_group String? @default("") // public model_name / model_group + api_base String? @default("") + user String? @default("") + metadata Json? @default("{}") + cache_hit String? @default("") + cache_key String? @default("") + request_tags Json? @default("[]") + team_id String? + end_user String? + requester_ip_address String? + @@index([startTime]) + @@index([end_user]) +} + +// View spend, model, api_key per request +model LiteLLM_ErrorLogs { + request_id String @id @default(uuid()) + startTime DateTime // Assuming start_time is a DateTime field + endTime DateTime // Assuming end_time is a DateTime field + api_base String @default("") + model_group String @default("") // public model_name / model_group + litellm_model_name String @default("") // model passed to litellm + model_id String @default("") // ID of model in ProxyModelTable + request_kwargs Json @default("{}") + exception_type String @default("") + exception_string String @default("") + status_code String @default("") +} + +// Beta - allow team members to request access to a model +model LiteLLM_UserNotifications { + request_id String @id + user_id String + models String[] + justification String + status String // approved, disapproved, pending +} + +model LiteLLM_TeamMembership { + // Use this table to track the Internal User's Spend within a Team + Set Budgets, rpm limits for the user within the team + user_id String + team_id String + spend Float @default(0.0) + budget_id String? + litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id]) + @@id([user_id, team_id]) +} + +model LiteLLM_InvitationLink { + // use this table to track invite links sent by admin for people to join the proxy + id String @id @default(uuid()) + user_id String + is_accepted Boolean @default(false) + accepted_at DateTime? // when link is claimed (user successfully onboards via link) + expires_at DateTime // till when is link valid + created_at DateTime // when did admin create the link + created_by String // who created the link + updated_at DateTime // when was invite status updated + updated_by String // who updated the status (admin/user who accepted invite) + + // Relations + liteLLM_user_table_user LiteLLM_UserTable @relation("UserId", fields: [user_id], references: [user_id]) + liteLLM_user_table_created LiteLLM_UserTable @relation("CreatedBy", fields: [created_by], references: [user_id]) + liteLLM_user_table_updated LiteLLM_UserTable @relation("UpdatedBy", fields: [updated_by], references: [user_id]) +} + + +model LiteLLM_AuditLog { + id String @id @default(uuid()) + updated_at DateTime @default(now()) + changed_by String @default("") // user or system that performed the action + changed_by_api_key String @default("") // api key hash that performed the action + action String // create, update, delete + table_name String // on of LitellmTableNames.TEAM_TABLE_NAME, LitellmTableNames.USER_TABLE_NAME, LitellmTableNames.PROXY_MODEL_TABLE_NAME, + object_id String // id of the object being audited. This can be the key id, team id, user id, model id + before_value Json? // value of the row + updated_values Json? // value of the row after change +} diff --git a/litellm/proxy/example_config_yaml/disable_schema_update.yaml b/litellm/proxy/example_config_yaml/disable_schema_update.yaml new file mode 100644 index 000000000..cc56b9516 --- /dev/null +++ b/litellm/proxy/example_config_yaml/disable_schema_update.yaml @@ -0,0 +1,12 @@ +model_list: + - model_name: fake-openai-endpoint + litellm_params: + model: openai/fake + api_key: fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + +litellm_settings: + callbacks: ["gcs_bucket"] + +general_settings: + disable_prisma_schema_update: true diff --git a/litellm/proxy/prisma_migration.py b/litellm/proxy/prisma_migration.py index 51dab42fa..ff26151df 100644 --- a/litellm/proxy/prisma_migration.py +++ b/litellm/proxy/prisma_migration.py @@ -47,6 +47,11 @@ retry_count = 0 max_retries = 3 exit_code = 1 +disable_schema_update = os.getenv("DISABLE_SCHEMA_UPDATE") +if disable_schema_update is not None and disable_schema_update == "True": + print("Skipping schema update...") # noqa + exit(0) + while retry_count < max_retries and exit_code != 0: retry_count += 1 print(f"Attempt {retry_count}...") # noqa diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 9f889d2a2..5db4f2d0a 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -461,6 +461,7 @@ def run_server( db_connection_pool_limit = 100 db_connection_timeout = 60 + general_settings = {} ### GET DB TOKEN FOR IAM AUTH ### if iam_token_db_auth: @@ -646,24 +647,37 @@ def run_server( is_prisma_runnable = False if is_prisma_runnable: - for _ in range(4): - # run prisma db push, before starting server - # Save the current working directory - original_dir = os.getcwd() - # set the working directory to where this script is - abspath = os.path.abspath(__file__) - dname = os.path.dirname(abspath) - os.chdir(dname) - try: - subprocess.run(["prisma", "db", "push", "--accept-data-loss"]) - break # Exit the loop if the subprocess succeeds - except subprocess.CalledProcessError as e: - import time + from litellm.proxy.db.check_migration import check_prisma_schema_diff + from litellm.proxy.db.prisma_client import should_update_schema - print(f"Error: {e}") # noqa - time.sleep(random.randrange(start=1, stop=5)) - finally: - os.chdir(original_dir) + if ( + should_update_schema( + general_settings.get("disable_prisma_schema_update") + ) + is False + ): + check_prisma_schema_diff(db_url=None) + else: + for _ in range(4): + # run prisma db push, before starting server + # Save the current working directory + original_dir = os.getcwd() + # set the working directory to where this script is + abspath = os.path.abspath(__file__) + dname = os.path.dirname(abspath) + os.chdir(dname) + try: + subprocess.run( + ["prisma", "db", "push", "--accept-data-loss"] + ) + break # Exit the loop if the subprocess succeeds + except subprocess.CalledProcessError as e: + import time + + print(f"Error: {e}") # noqa + time.sleep(random.randrange(start=1, stop=5)) + finally: + os.chdir(original_dir) else: print( # noqa f"Unable to connect to DB. DATABASE_URL found in environment, but prisma package not found." # noqa diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index a53da4512..aaccd56a3 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1055,27 +1055,7 @@ class PrismaClient: try: from prisma import Prisma # type: ignore except Exception: - os.environ["DATABASE_URL"] = database_url - # Save the current working directory - original_dir = os.getcwd() - # set the working directory to where this script is - abspath = os.path.abspath(__file__) - dname = os.path.dirname(abspath) - os.chdir(dname) - - try: - subprocess.run(["prisma", "generate"]) - subprocess.run( - ["prisma", "db", "push", "--accept-data-loss"] - ) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss - except Exception as e: - raise Exception( - f"Unable to run prisma commands. Run `pip install prisma` Got Exception: {(str(e))}" - ) - finally: - os.chdir(original_dir) - # Now you can import the Prisma Client - from prisma import Prisma # type: ignore + raise Exception("Unable to find Prisma binaries.") verbose_proxy_logger.debug("Connecting Prisma Client to DB..") if http_client is not None: self.db = PrismaWrapper( diff --git a/tests/documentation_tests/test_general_setting_keys.py b/tests/documentation_tests/test_general_setting_keys.py new file mode 100644 index 000000000..296c5c403 --- /dev/null +++ b/tests/documentation_tests/test_general_setting_keys.py @@ -0,0 +1,76 @@ +import os +import re + +# Define the base directory for the litellm repository and documentation path +repo_base = "./litellm" # Change this to your actual path + + +# Regular expressions to capture the keys used in general_settings.get() and general_settings[] +get_pattern = re.compile( + r'general_settings\.get\(\s*[\'"]([^\'"]+)[\'"](,?\s*[^)]*)?\)' +) +bracket_pattern = re.compile(r'general_settings\[\s*[\'"]([^\'"]+)[\'"]\s*\]') + +# Set to store unique keys from the code +general_settings_keys = set() + +# Walk through all files in the litellm repo to find references of general_settings +for root, dirs, files in os.walk(repo_base): + for file in files: + if file.endswith(".py"): # Only process Python files + file_path = os.path.join(root, file) + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + # Find all keys using general_settings.get() + get_matches = get_pattern.findall(content) + general_settings_keys.update( + match[0] for match in get_matches + ) # Extract only the key part + + # Find all keys using general_settings[] + bracket_matches = bracket_pattern.findall(content) + general_settings_keys.update(bracket_matches) + +# Parse the documentation to extract documented keys +repo_base = "./" +print(os.listdir(repo_base)) +docs_path = "./docs/my-website/docs/proxy/configs.md" # Path to the documentation +documented_keys = set() +try: + with open(docs_path, "r", encoding="utf-8") as docs_file: + content = docs_file.read() + + # Find the section titled "general_settings - Reference" + general_settings_section = re.search( + r"### general_settings - Reference(.*?)###", content, re.DOTALL + ) + if general_settings_section: + # Extract the table rows, which contain the documented keys + table_content = general_settings_section.group(1) + doc_key_pattern = re.compile( + r"\|\s*([^\|]+?)\s*\|" + ) # Capture the key from each row of the table + documented_keys.update(doc_key_pattern.findall(table_content)) +except Exception as e: + raise Exception( + f"Error reading documentation: {e}, \n repo base - {os.listdir(repo_base)}" + ) + +# Compare and find undocumented keys +undocumented_keys = general_settings_keys - documented_keys + +# Print results +print("Keys expected in 'general_settings' (found in code):") +for key in sorted(general_settings_keys): + print(key) + +if undocumented_keys: + raise Exception( + f"\nKeys not documented in 'general_settings - Reference': {undocumented_keys}" + ) +else: + print( + "\nAll keys are documented in 'general_settings - Reference'. - {}".format( + general_settings_keys + ) + )