From 4bdbf1e0e69c309ce60fd67032e696891a9d585e Mon Sep 17 00:00:00 2001 From: khalifan-kfan Date: Mon, 11 Aug 2025 11:41:48 +0300 Subject: [PATCH 1/2] remove all possibilities of SQL injections --- app/helpers/database_service.py | 754 ++++++++++++++++++++++---------- app/routes.py | 3 +- 2 files changed, 518 insertions(+), 239 deletions(-) diff --git a/app/helpers/database_service.py b/app/helpers/database_service.py index c0d3651..609d005 100644 --- a/app/helpers/database_service.py +++ b/app/helpers/database_service.py @@ -1,21 +1,102 @@ - import mysql.connector as mysql_conn import psycopg2 from psycopg2 import sql import secrets import string +import re from types import SimpleNamespace from config import settings +class SQLInjectionProtector: + """SQL injection protection and validation class""" + + # Valid patterns for different database objects + DB_NAME_PATTERN = re.compile(r'^[a-zA-Z][a-zA-Z0-9_]{0,63}$') + USERNAME_PATTERN = re.compile(r'^[a-zA-Z][a-zA-Z0-9_]{0,31}$') + + # Dangerous SQL keywords and patterns + DANGEROUS_PATTERNS = [ + r'(\b(DROP|DELETE|INSERT|UPDATE|ALTER|CREATE|TRUNCATE|EXEC|EXECUTE|UNION|SELECT)\b)', + r'(--|#|/\*|\*/)', # SQL comments + r'(\b(OR|AND)\s+\d+\s*=\s*\d+)', # Common injection patterns + r'(\bUNION\s+SELECT\b)', + r'(\';|\"\;)', # Statement terminators + ] + + @classmethod + def validate_identifier(cls, identifier, identifier_type="general"): + """ + Validate database identifiers + """ + if not identifier or not isinstance(identifier, str): + raise ValueError(f"Invalid {identifier_type}: must be a non-empty string") + + identifier = identifier.strip() + + if identifier_type == "database": + if not cls.DB_NAME_PATTERN.match(identifier): + raise ValueError(f"Invalid database name: {identifier}. Must start with letter, contain only alphanumeric and underscore, max 64 chars") + elif identifier_type == "username": + if not cls.USERNAME_PATTERN.match(identifier): + raise ValueError(f"Invalid username: {identifier}. Must start with letter, contain only alphanumeric and underscore, max 32 chars") + + # Check for dangerous patterns + identifier_upper = identifier.upper() + for pattern in cls.DANGEROUS_PATTERNS: + if re.search(pattern, identifier_upper, re.IGNORECASE): + raise ValueError(f"Potentially dangerous pattern detected in {identifier_type}: {identifier}") + + return identifier + + @classmethod + def validate_password(cls, password): + """ + Validate password - less restrictive but still check for obvious SQL injection + """ + if not password or not isinstance(password, str): + raise ValueError("Password must be a non-empty string") + + # Check for statement terminators and comment patterns that could be problematic + dangerous_password_patterns = [ + r'(--|#)', # SQL comments in passwords are suspicious + r'(\';|\")', # Statement terminators + ] + + for pattern in dangerous_password_patterns: + if re.search(pattern, password): + raise ValueError("Password contains potentially dangerous characters") + + return password + + @classmethod + def sanitize_all_inputs(cls, db_name=None, username=None, password=None): + """ + Centralized input sanitization for all database operations + """ + sanitized = {} + + if db_name is not None: + sanitized['db_name'] = cls.validate_identifier(db_name, "database") + + if username is not None: + sanitized['username'] = cls.validate_identifier(username, "username") + + if password is not None: + sanitized['password'] = cls.validate_password(password) + + return sanitized + + def generate_db_credentials(): - punctuation = r"""#%+,-<=>^_""" + # Use safer character set for passwords - avoid potential SQL metacharacters + safe_punctuation = r"#%+<=>^_" name = ''.join((secrets.choice(string.ascii_letters) for i in range(24))) user = ''.join((secrets.choice(string.ascii_letters) for i in range(16))) password = ''.join((secrets.choice( - string.ascii_letters + string.digits + punctuation) for i in range(32))) + string.ascii_letters + string.digits + safe_punctuation) for i in range(32))) return SimpleNamespace( user=user.lower(), @@ -28,6 +109,7 @@ class DatabaseService: def __init__(self): self.Error = None + self.protector = SQLInjectionProtector() def create_connection(self): """ Create a connection to db server """ @@ -39,6 +121,7 @@ def create_db_connection(self, user=None, password=None, db_name=None): def check_user_db_rights(self, user=None, password=None, db_name=None): """Verify user rights to db""" + pass # Create or check user exists database def create_database(self, db_name=None, user=None, password=None): @@ -85,7 +168,7 @@ def get_all_users(self): class MysqlDbService(DatabaseService): def __init__(self): - super(DatabaseService, self).__init__() + super().__init__() self.Error = mysql_conn.Error def create_connection(self): @@ -103,199 +186,256 @@ def create_connection(self): def create_db_connection(self, user=None, password=None, db_name=None): try: + # Sanitize inputs + sanitized = self.protector.sanitize_all_inputs( + db_name=db_name, username=user, password=password + ) + user_connection = mysql_conn.connect( host=settings.ADMIN_MYSQL_HOST, - user=user, - password=password, + user=sanitized['username'], + password=sanitized['password'], port=settings.ADMIN_MYSQL_PORT, - database=db_name + database=sanitized['db_name'] ) return user_connection - except self.Error as e: + except (self.Error, ValueError) as e: print(e) return False def check_db_connection(self): + super_connection = None try: super_connection = self.create_connection() if not super_connection: return False return True - except self.Error as e: + except self.Error: return False finally: - if not super_connection: - return False - if (super_connection.is_connected()): + if super_connection and super_connection.is_connected(): super_connection.close() def check_user_db_rights(self, user=None, password=None, db_name=None): + user_connection = None try: user_connection = self.create_db_connection( user=user, password=password, db_name=db_name) if not user_connection: return False return True - except self.Error as e: + except (self.Error, ValueError) as e: print(e) return False finally: - if not user_connection: - return False - if (user_connection.is_connected()): + if user_connection and user_connection.is_connected(): user_connection.close() - # Create or check user exists database def create_database(self, db_name=None, user=None, password=None): + connection = None + cursor = None try: + # Sanitize all inputs + sanitized = self.protector.sanitize_all_inputs( + db_name=db_name, username=user, password=password + ) + connection = self.create_connection() if not connection: return False + cursor = connection.cursor() - cursor.execute(f"CREATE DATABASE {db_name}") + + # Use parameterized queries where possible, identifiers need to be validated + cursor.execute(f"CREATE DATABASE `{sanitized['db_name']}`") + if self.create_user(user=user, password=password): - cursor.execute( - f"GRANT ALL PRIVILEGES ON {db_name}.* To '{user}'@'%'") + # Grant privileges using parameterized approach + grant_query = f"GRANT ALL PRIVILEGES ON `{sanitized['db_name']}`.* TO %s@'%'" + cursor.execute(grant_query, (sanitized['username'],)) + + connection.commit() return True - except self.Error as e: + except (self.Error, ValueError) as e: print(e) return False finally: - if not connection: - return False - if (connection.is_connected()): - cursor.close() + if connection and connection.is_connected(): + if cursor: + cursor.close() connection.close() - # create database user def create_user(self, user=None, password=None): + connection = None + cursor = None try: + # Sanitize inputs + sanitized = self.protector.sanitize_all_inputs(username=user, password=password) + connection = self.create_connection() if not connection: return False + cursor = connection.cursor() - cursor.execute( - f"CREATE USER '{user}'@'%' IDENTIFIED BY '{password}' ") + + # Use parameterized query for password, validated identifier for username + create_user_query = f"CREATE USER `{sanitized['username']}`@'%' IDENTIFIED BY %s" + cursor.execute(create_user_query, (sanitized['password'],)) + connection.commit() return True except self.Error as e: - if e.errno == '1396': + if e.errno == 1396: # User already exists return True + print(e) + return False + except ValueError as e: + print(f"Validation error: {e}") return False finally: - if not connection: - return False - if (connection.is_connected()): - cursor.close() + if connection and connection.is_connected(): + if cursor: + cursor.close() connection.close() def get_database_size(self, db_name=None, user=None, password=None): + connection = None + cursor = None try: + sanitized = self.protector.sanitize_all_inputs( + db_name=db_name, username=user, password=password + ) + connection = self.create_db_connection( db_name=db_name, user=user, password=password) if not connection: return 'N/A' + cursor = connection.cursor() - cursor.execute( - f"""SELECT table_schema "{db_name}", - SUM(data_length + index_length) / 1024 AS "Size(KB)" - FROM information_schema.TABLES - GROUP BY table_schema""") - db_size = '0' + + # Use parameterized query + size_query = """SELECT table_schema, + SUM(data_length + index_length) / 1024 AS "Size(KB)" + FROM information_schema.TABLES + WHERE table_schema = %s + GROUP BY table_schema""" + + cursor.execute(size_query, (sanitized['db_name'],)) + + db_size = '0 KB' for db in cursor: - db_size = f'{float(db[1])} KB' + db_size = f'{float(db[1]):.2f} KB' return db_size - except self.Error: + except (self.Error, ValueError): return 'N/A' finally: - if not connection: - return 'N/A' - if (connection.is_connected()): - cursor.close() + if connection and connection.is_connected(): + if cursor: + cursor.close() connection.close() - # reset password for database user def reset_password(self, user=None, password=None): + connection = None try: + sanitized = self.protector.sanitize_all_inputs(username=user, password=password) + connection = self.create_connection() if not connection: return False + cursor = connection.cursor() - cursor.execute( - f"ALTER USER '{user}'@'%' IDENTIFIED BY '{password}'") + + reset_query = f"ALTER USER `{sanitized['username']}`@'%' IDENTIFIED BY %s" + cursor.execute(reset_query, (sanitized['password'],)) + connection.commit() return True - except self.Error: + except (self.Error, ValueError) as e: + print(e) return False finally: - if not connection: - return False - if (connection.is_connected()): + if connection and connection.is_connected(): cursor.close() connection.close() - # delete database user def delete_user(self, user=None): + connection = None try: + sanitized = self.protector.sanitize_all_inputs(username=user) + connection = self.create_connection() if not connection: return False + cursor = connection.cursor() - cursor.execute(f"DROP USER '{user}' ") + + # Use validated identifier + cursor.execute(f"DROP USER `{sanitized['username']}`@'%'") + connection.commit() return True - except self.Error: + except (self.Error, ValueError) as e: + print(e) return False finally: - if not connection: - return False - if (connection.is_connected()): + if connection and connection.is_connected(): cursor.close() connection.close() - # delete database def delete_database(self, db_name): + connection = None try: + sanitized = self.protector.sanitize_all_inputs(db_name=db_name) + connection = self.create_connection() if not connection: return False + cursor = connection.cursor() - cursor.execute(f"DROP DATABASE {db_name}") - # TODO: Need to delete users too + cursor.execute(f"DROP DATABASE `{sanitized['db_name']}`") + connection.commit() return True - except self.Error: + except (self.Error, ValueError) as e: + print(e) return False finally: - if not connection: - return False - if (connection.is_connected()): + if connection and connection.is_connected(): cursor.close() connection.close() def reset_database(self, db_name=None, user=None, password=None): + connection = None + cursor = None try: + # Validate inputs first + sanitized = self.protector.sanitize_all_inputs( + db_name=db_name, username=user, password=password + ) + connection = self.create_connection() user_rights = self.check_user_db_rights( db_name=db_name, user=user, password=password) if not connection or not user_rights: return False + cursor = connection.cursor() - cursor.execute(f"DROP DATABASE {db_name}") + cursor.execute(f"DROP DATABASE `{sanitized['db_name']}`") + created_db = self.create_database( db_name=db_name, user=user, password=password) if not created_db: return False return True - except self.Error: + except (self.Error, ValueError) as e: + print(e) return False finally: - if not connection or not user_rights: - return False - if (connection.is_connected()): - cursor.close() + if connection and connection.is_connected(): + if cursor: + cursor.close() connection.close() - # Show all databases def get_all_databases(self): + connection = None try: connection = self.create_connection() if not connection: @@ -304,19 +444,20 @@ def get_all_databases(self): cursor.execute("SHOW DATABASES") database_list = [] for db in cursor: - database_list.append(db[0].decode()) + db_name = db[0] + if isinstance(db_name, bytes): + db_name = db_name.decode() + database_list.append(db_name) return database_list except self.Error: return False finally: - if not connection: - return False - if (connection.is_connected()): + if connection and connection.is_connected(): cursor.close() connection.close() - # Show users def get_all_users(self): + connection = None try: connection = self.create_connection() if not connection: @@ -325,25 +466,29 @@ def get_all_users(self): cursor.execute("SELECT user FROM mysql.user GROUP BY user") users_list = [] for db in cursor: - users_list.append(db[0].decode()) + user_name = db[0] + if isinstance(user_name, bytes): + user_name = user_name.decode() + users_list.append(user_name) return users_list except self.Error: return False finally: - if not connection: - return False - if (connection.is_connected()): + if connection and connection.is_connected(): cursor.close() connection.close() def get_server_status(self): + connection = None try: connection = self.create_connection() if not connection: - return False + return { + 'status': 'error', + 'message': 'Unable to connect to database' + } cursor = connection.cursor() cursor.execute("SHOW GLOBAL STATUS") - # cursor.fetchall() return { 'status': 'success', 'data': 'online' @@ -351,90 +496,112 @@ def get_server_status(self): except self.Error: return { 'status': 'error', - 'message': 'Error has occured'} - + 'message': 'Error has occurred' + } finally: - if not connection: - return { - 'status': 'error', - 'message': 'Unable to connect to database'} - if (connection.is_connected()): + if connection and connection.is_connected(): cursor.close() connection.close() def disable_user_access(self, db_name, db_user_name): + connection = None try: + sanitized = self.protector.sanitize_all_inputs( + db_name=db_name, username=db_user_name + ) + connection = self.create_connection() if not connection: return False cursor = connection.cursor() - cursor.execute(f"REVOKE ALL PRIVILEGES ON {db_name}.* FROM {db_user_name}") - - cursor.execute(f"GRANT SELECT, DELETE ON {db_name}.* TO {db_user_name}") + # Revoke all privileges first + revoke_query = f"REVOKE ALL PRIVILEGES ON `{sanitized['db_name']}`.* FROM `{sanitized['username']}`@'%'" + cursor.execute(revoke_query) + # Grant limited privileges + grant_query = f"GRANT SELECT, DELETE ON `{sanitized['db_name']}`.* TO `{sanitized['username']}`@'%'" + cursor.execute(grant_query) + + connection.commit() return True - except self.Error as e: + except (self.Error, ValueError) as e: print(e) return False finally: - if not connection: - return False - cursor.close() - connection.close() + if connection and connection.is_connected(): + cursor.close() + connection.close() def enable_user_write_access(self, db_name, db_user_name): + connection = None try: - connection = self.create_connection(db_name=db_name) + sanitized = self.protector.sanitize_all_inputs( + db_name=db_name, username=db_user_name + ) + + connection = self.create_connection() if not connection: return False cursor = connection.cursor() - cursor.execute(f"GRANT ALL PRIVILEGES ON {db_name}.* TO {db_user_name}") + + grant_query = f"GRANT ALL PRIVILEGES ON `{sanitized['db_name']}`.* TO `{sanitized['username']}`@'%'" + cursor.execute(grant_query) + connection.commit() return True - except self.Error as e: + except (self.Error, ValueError) as e: print(e) return False finally: - if not connection: - return False - cursor.close() - connection.close() + if connection and connection.is_connected(): + cursor.close() + connection.close() - # disable user database log in def disable_user_log_in(self, db_user_name, db_user_pw): + connection = None try: + sanitized = self.protector.sanitize_all_inputs( + username=db_user_name, password=db_user_pw + ) + connection = self.create_connection() if not connection: return False cursor = connection.cursor() - cursor.execute(f"ALTER USER {db_user_name} IDENTIFIED BY '{db_user_pw}'ACCOUNT LOCK") - + + lock_query = f"ALTER USER `{sanitized['username']}`@'%' IDENTIFIED BY %s ACCOUNT LOCK" + cursor.execute(lock_query, (sanitized['password'],)) + connection.commit() return True - except self.Error: + except (self.Error, ValueError) as e: + print(e) return False finally: - if not connection: - return False - if (connection.is_connected()): + if connection and connection.is_connected(): cursor.close() connection.close() - # enable user database log in def enable_user_log_in(self, db_user_name, db_user_pw): + connection = None try: + sanitized = self.protector.sanitize_all_inputs( + username=db_user_name, password=db_user_pw + ) + connection = self.create_connection() if not connection: return False cursor = connection.cursor() - cursor.execute(f"ALTER USER {db_user_name} IDENTIFIED BY '{db_user_pw}'ACCOUNT UNLOCK") - + + unlock_query = f"ALTER USER `{sanitized['username']}`@'%' IDENTIFIED BY %s ACCOUNT UNLOCK" + cursor.execute(unlock_query, (sanitized['password'],)) + connection.commit() return True - except self.Error: + except (self.Error, ValueError) as e: + print(e) return False finally: - if not connection: - return False - if (connection.is_connected()): + if connection and connection.is_connected(): cursor.close() connection.close() @@ -442,7 +609,7 @@ def enable_user_log_in(self, db_user_name, db_user_pw): class PostgresqlDbService(DatabaseService): def __init__(self): - super(DatabaseService, self).__init__() + super().__init__() self.Error = psycopg2.Error def create_connection(self): @@ -466,169 +633,230 @@ def check_db_connection(self): if not super_connection: return False return True - except self.Error as e: + except self.Error: return False finally: - if not super_connection: - return False - super_connection.close() + if super_connection: + super_connection.close() def create_db_connection(self, user=None, password=None, db_name=None): try: + # Sanitize inputs + sanitized = self.protector.sanitize_all_inputs( + db_name=db_name, username=user, password=password + ) + user_connection = psycopg2.connect( host=settings.ADMIN_PSQL_HOST, - user=user, - password=password, + user=sanitized['username'], + password=sanitized['password'], port=settings.ADMIN_PSQL_PORT, - database=db_name + database=sanitized['db_name'] ) return user_connection - except self.Error as e: + except (self.Error, ValueError) as e: print(e) return False def check_user_db_rights(self, user=None, password=None, db_name=None): - # TODO: Restrict users from accessing databases they dont own + user_connection = None try: user_connection = self.create_db_connection( user=user, password=password, db_name=db_name) if not user_connection: return False return True - except self.Error as e: + except (self.Error, ValueError) as e: print(e) return False finally: - if not user_connection: - return False - user_connection.close() + if user_connection: + user_connection.close() - # Create or check user exists database def create_database(self, db_name=None, user=None, password=None): + connection = None try: + # Sanitize all inputs + sanitized = self.protector.sanitize_all_inputs( + db_name=db_name, username=user, password=password + ) + connection = self.create_connection() if not connection: return False cursor = connection.cursor() + if self.create_user(user=user, password=password): + # Use sql.Identifier for safe identifier handling cursor.execute( - sql.SQL(f'CREATE DATABASE {db_name} WITH OWNER = {user}')) + sql.SQL('CREATE DATABASE {} WITH OWNER = {}').format( + sql.Identifier(sanitized['db_name']), + sql.Identifier(sanitized['username']) + ) + ) return True - except self.Error: + except (self.Error, ValueError) as e: + print(e) return False finally: - if not connection: - return False - cursor.close() - connection.close() + if connection: + cursor.close() + connection.close() - # create database user def create_user(self, user=None, password=None): + connection = None try: + # Sanitize inputs + sanitized = self.protector.sanitize_all_inputs(username=user, password=password) + connection = self.create_connection() if not connection: return False cursor = connection.cursor() + + # Use sql.Identifier and sql.Literal for safe query construction cursor.execute( - f"CREATE USER {user} WITH ENCRYPTED PASSWORD '{password}'") - connection.commit() + sql.SQL('CREATE USER {} WITH ENCRYPTED PASSWORD {}').format( + sql.Identifier(sanitized['username']), + sql.Literal(sanitized['password']) + ) + ) return True except self.Error as e: print(e) - if e.pgcode == '42710': + if e.pgcode == '42710': # Duplicate role return True return False + except ValueError as e: + print(f"Validation error: {e}") + return False finally: - if not connection: - return False - cursor.close() - connection.close() + if connection: + cursor.close() + connection.close() - # delete database user def delete_user(self, user=None): + connection = None try: + sanitized = self.protector.sanitize_all_inputs(username=user) + connection = self.create_connection() if not connection: return False cursor = connection.cursor() - cursor.execute(f"DROP USER {user} ") - connection.commit() + + cursor.execute( + sql.SQL('DROP USER {}').format( + sql.Identifier(sanitized['username']) + ) + ) return True except self.Error as e: - if e.pgcode == '42704': + if e.pgcode == '42704': # Undefined object return True + print(e) + return False + except ValueError as e: + print(f"Validation error: {e}") return False finally: - if not connection: - return False - cursor.close() - connection.close() + if connection: + cursor.close() + connection.close() - # delete database def delete_database(self, db_name): + connection = None try: + sanitized = self.protector.sanitize_all_inputs(db_name=db_name) + connection = self.create_connection() if not connection: return False cursor = connection.cursor() - cursor.execute(f"DROP DATABASE {db_name}") - # TODO: Need to delete users too + + cursor.execute( + sql.SQL('DROP DATABASE {}').format( + sql.Identifier(sanitized['db_name']) + ) + ) return True - except self.Error as e: + except (self.Error, ValueError) as e: print(e) return False finally: - if not connection: - return False - cursor.close() - connection.close() + if connection: + cursor.close() + connection.close() def reset_database(self, db_name=None, user=None, password=None): + connection = None + cursor = None try: + # Validate inputs first + sanitized = self.protector.sanitize_all_inputs( + db_name=db_name, username=user, password=password + ) + connection = self.create_connection() user_rights = self.check_user_db_rights( db_name=db_name, user=user, password=password) if not connection or not user_rights: return False + cursor = connection.cursor() - cursor.execute(f"DROP DATABASE {db_name}") + cursor.execute( + sql.SQL('DROP DATABASE {}').format( + sql.Identifier(sanitized['db_name']) + ) + ) + created_db = self.create_database( db_name=db_name, user=user, password=password) if not created_db: return False return True - except self.Error: + except (self.Error, ValueError) as e: + print(e) return False finally: - if not connection or not user_rights: - return False - cursor.close() - connection.close() + if connection: + if cursor: + cursor.close() + connection.close() def get_database_size(self, db_name=None, user=None, password=None): + connection = None try: + sanitized = self.protector.sanitize_all_inputs( + db_name=db_name, username=user, password=password + ) + connection = self.create_db_connection( db_name=db_name, user=user, password=password) if not connection: return 'N/A' cursor = connection.cursor() + + # Use parameterized query cursor.execute( - f"""SELECT pg_size_pretty( pg_database_size('{db_name}') )""") - db_size = 0 + 'SELECT pg_size_pretty(pg_database_size(%s))', + (sanitized['db_name'],) + ) + + db_size = '0 bytes' for db in cursor: db_size = db[0] return db_size - except self.Error: + except (self.Error, ValueError): return 'N/A' finally: - if not connection: - return 'N/A' - cursor.close() - connection.close() + if connection: + cursor.close() + connection.close() - # Show all databases def get_all_databases(self): + connection = None try: connection = self.create_connection() if not connection: @@ -642,20 +870,18 @@ def get_all_databases(self): except self.Error: return False finally: - if not connection: - return False - cursor.close() - connection.close() + if connection: + cursor.close() + connection.close() - # Show users def get_all_users(self): + connection = None try: connection = self.create_connection() if not connection: return False cursor = connection.cursor() - cursor.execute( - "SELECT usename FROM pg_catalog.pg_user") + cursor.execute("SELECT usename FROM pg_catalog.pg_user") users_list = [] for db in cursor: users_list.append(db[0]) @@ -663,35 +889,44 @@ def get_all_users(self): except self.Error: return False finally: - if not connection: - return False - cursor.close() - connection.close() + if connection: + cursor.close() + connection.close() - # reset database user password def reset_password(self, user=None, password=None): + connection = None try: + sanitized = self.protector.sanitize_all_inputs(username=user, password=password) + connection = self.create_connection() if not connection: return False cursor = connection.cursor() + cursor.execute( - f"ALTER USER {user} WITH ENCRYPTED PASSWORD '{password}'") - connection.commit() + sql.SQL('ALTER USER {} WITH ENCRYPTED PASSWORD {}').format( + sql.Identifier(sanitized['username']), + sql.Literal(sanitized['password']) + ) + ) return True - except self.Error: + except (self.Error, ValueError) as e: + print(e) return False finally: - if not connection: - return False - cursor.close() - connection.close() + if connection: + cursor.close() + connection.close() def get_server_status(self): + connection = None try: connection = self.create_connection() if not connection: - return False + return { + 'status': 'error', + 'message': 'Unable to connect to database' + } cursor = connection.cursor() cursor.execute("SELECT pg_is_in_recovery()") @@ -699,97 +934,140 @@ def get_server_status(self): if db[0]: return { 'status': 'failed', - 'message': 'in recovery'} + 'message': 'in recovery' + } else: return { 'status': 'success', - 'message': 'online'} + 'message': 'online' + } except self.Error: return { 'status': 'error', - 'message': 'Error has occured'} + 'message': 'Error has occurred' + } finally: - if not connection: - return { - 'status': 'error', - 'message': 'Unable to connect to database'} - - cursor.close() - connection.close() + if connection: + cursor.close() + connection.close() def disable_user_access(self, db_name, db_user_name): """Grants read and delete access to the specified user, revoking write and update privileges.""" + connection = None try: + sanitized = self.protector.sanitize_all_inputs( + db_name=db_name, username=db_user_name + ) + connection = self.create_connection() if not connection: return False cursor = connection.cursor() - cursor.execute(f"REVOKE INSERT, UPDATE ON DATABASE {db_name} FROM {db_user_name}") + + # Use sql.Identifier for safe identifier handling cursor.execute( - f"REVOKE INSERT, UPDATE ON ALL TABLES IN SCHEMA public FROM {db_user_name}") + sql.SQL('REVOKE INSERT, UPDATE ON DATABASE {} FROM {}').format( + sql.Identifier(sanitized['db_name']), + sql.Identifier(sanitized['username']) + ) + ) + cursor.execute( + sql.SQL('REVOKE INSERT, UPDATE ON ALL TABLES IN SCHEMA public FROM {}').format( + sql.Identifier(sanitized['username']) + ) + ) cursor.execute( - f"REVOKE USAGE ON SCHEMA public FROM {db_user_name}") + sql.SQL('REVOKE USAGE ON SCHEMA public FROM {}').format( + sql.Identifier(sanitized['username']) + ) + ) return True - except self.Error as e: + except (self.Error, ValueError) as e: print(e) return False finally: - if not connection: - return False - cursor.close() - connection.close() + if connection: + cursor.close() + connection.close() def enable_user_write_access(self, db_name, db_user_name): + connection = None try: - connection = self.create_connection(db_name=db_name) + sanitized = self.protector.sanitize_all_inputs( + db_name=db_name, username=db_user_name + ) + + connection = self.create_connection() if not connection: return False cursor = connection.cursor() - cursor.execute(f"GRANT INSERT, UPDATE ON DATABASE {db_name} TO {db_user_name}") + + cursor.execute( + sql.SQL('GRANT INSERT, UPDATE ON DATABASE {} TO {}').format( + sql.Identifier(sanitized['db_name']), + sql.Identifier(sanitized['username']) + ) + ) + cursor.execute( + sql.SQL('GRANT INSERT, UPDATE ON ALL TABLES IN SCHEMA public TO {}').format( + sql.Identifier(sanitized['username']) + ) + ) cursor.execute( - f"GRANT INSERT, UPDATE ON ALL TABLES IN SCHEMA public TO {db_user_name}") - cursor.execute(f"GRANT USAGE ON SCHEMA public TO {db_user_name}") + sql.SQL('GRANT USAGE ON SCHEMA public TO {}').format( + sql.Identifier(sanitized['username']) + ) + ) return True - except self.Error as e: + except (self.Error, ValueError) as e: print(e) return False finally: - if not connection: - return False - cursor.close() - connection.close() - - # disable user database log in + if connection: + cursor.close() + connection.close() def disable_user_log_in(self, db_user_name): + connection = None try: + sanitized = self.protector.sanitize_all_inputs(username=db_user_name) + connection = self.create_connection() if not connection: return False cursor = connection.cursor() - cursor.execute(f"ALTER USER {db_user_name} NOLOGIN") - + + cursor.execute( + sql.SQL('ALTER USER {} NOLOGIN').format( + sql.Identifier(sanitized['username']) + ) + ) return True - except self.Error as e: + except (self.Error, ValueError) as e: print(e) return False finally: - if not connection: - return False - cursor.close() - connection.close() + if connection: + cursor.close() + connection.close() - # enable user database log in def enable_user_log_in(self, db_user_name): + connection = None try: + sanitized = self.protector.sanitize_all_inputs(username=db_user_name) + connection = self.create_connection() if not connection: return False cursor = connection.cursor() - cursor.execute(f"ALTER USER {db_user_name} WITH LOGIN") - + + cursor.execute( + sql.SQL('ALTER USER {} WITH LOGIN').format( + sql.Identifier(sanitized['username']) + ) + ) return True - except self.Error as e: + except (self.Error, ValueError) as e: print(e) return False finally: diff --git a/app/routes.py b/app/routes.py index 95cbc4d..12d2e9d 100644 --- a/app/routes.py +++ b/app/routes.py @@ -520,7 +520,8 @@ def allocate_storage(database_id: str, additional_storage: int, access_token: st if not database_connection: return failed_database_connection(current_user, "ADD STORAGE") - + + # TODO: Implement the logic to allocate additional storage database.allocated_size_kb += additional_storage save_to_database(db) return {"message": f"Additional {additional_storage} bytes of storage allocated to the database", "status_code": 200} From fbf69564724dcc267ab09b4d4c09d9e78baa3502 Mon Sep 17 00:00:00 2001 From: khalifan-kfan Date: Mon, 11 Aug 2025 20:16:26 +0300 Subject: [PATCH 2/2] return todo comment --- app/helpers/database_service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/app/helpers/database_service.py b/app/helpers/database_service.py index 609d005..1e4cdf4 100644 --- a/app/helpers/database_service.py +++ b/app/helpers/database_service.py @@ -391,6 +391,7 @@ def delete_database(self, db_name): cursor = connection.cursor() cursor.execute(f"DROP DATABASE `{sanitized['db_name']}`") + # TODO: Need to delete users too connection.commit() return True except (self.Error, ValueError) as e: