import re
import csv
from pathlib import Path

def analyze_sql_files(root_path):
    """
    Recursively scan folders for SQL files and analyze table references.
    
    Args:
        root_path (str): Root directory path to start scanning

    Returns:
        tuple: (source_tables, created_tables) sets of table names.
            Also writes source_tables.csv and created_tables.csv to the
            current working directory.
    """
    # Sets to store table names
    source_tables = set()
    created_tables = set()
    
    # Regular expressions for matching SQL patterns
    create_pattern = re.compile(
        r'create\s+(?:or\s+replace\s+)?'  # CREATE or CREATE OR REPLACE
        r'(?:(?:temp|temporary)\s+)?'      # Optional TEMP/TEMPORARY keyword
        r'table\s+'                        # TABLE keyword
        r'([^\s\(]+)',                     # Table name
        re.IGNORECASE
    )
    
    from_pattern = re.compile(r'from\s+([^\s\,\(\)]+)', re.IGNORECASE)
    join_pattern = re.compile(r'join\s+([^\s\,\(\)]+)', re.IGNORECASE)
    
    insert_pattern = re.compile(
        r'insert\s+'                       # INSERT
        r'(?:overwrite\s+)?'              # Optional OVERWRITE keyword
        r'(?:into\s+)?'                   # Optional INTO keyword
        r'([^\s\(]+)',                    # Table name
        re.IGNORECASE
    )
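    # Illustrative examples (assumed Snowflake-style SQL) of what each pattern captures:
    #   create_pattern: "CREATE OR REPLACE TABLE analytics.stg.orders AS SELECT ..." -> analytics.stg.orders
    #                   "CREATE TEMP TABLE tmp_orders (id INT)"                      -> tmp_orders
    #   from_pattern:   "... FROM raw.sales s"                                       -> raw.sales
    #   join_pattern:   "... JOIN raw.customers c ON ..."                            -> raw.customers
    #   insert_pattern: "INSERT OVERWRITE INTO analytics.stg.orders SELECT ..."      -> analytics.stg.orders
    # These regexes are intentionally simple: CTE names referenced after FROM/JOIN
    # are captured as if they were tables, while subqueries are skipped because the
    # capture group does not allow an opening parenthesis.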
    
    def clean_table_name(table_name):
        """
        Clean and format Snowflake table names.
        Preserves database.schema.table_name format if present.
        """
        # Remove surrounding whitespace
        table_name = table_name.strip()
        # Strip quotes from each dot-separated part so that
        # "DB"."SCHEMA"."TABLE" becomes DB.SCHEMA.TABLE
        parts = [part.strip('"\'`') for part in table_name.split('.')]
        return '.'.join(parts)
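    # Illustrative examples with hypothetical identifiers:
    #   clean_table_name('"TMP_ORDERS"')                   -> 'TMP_ORDERS'
    #   clean_table_name('"ANALYTICS"."PUBLIC"."ORDERS"')  -> 'ANALYTICS.PUBLIC.ORDERS'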
    
    def split_commands(sql_content):
        """
        Split SQL content into individual commands based on semicolons.
        Semicolons inside quoted strings do not end a command, and SQL
        comments (-- and /* ... */) are stripped so that keywords inside
        them are ignored.
        """
        commands = []
        current_command = []
        in_quote = False
        quote_char = None
        in_comment = False
        in_multiline_comment = False
        i = 0
        
        while i < len(sql_content):
            char = sql_content[i]
            
            # Handle quotes (quote characters stay part of the command text)
            if char in ("'", '"', '`') and not in_comment and not in_multiline_comment:
                if not in_quote:
                    in_quote = True
                    quote_char = char
                elif quote_char == char:
                    in_quote = False
                current_command.append(char)

            # Handle multi-line comment delimiters
            elif char == '/' and i + 1 < len(sql_content) and sql_content[i + 1] == '*' and not in_quote and not in_comment:
                in_multiline_comment = True
                i += 1
            elif char == '*' and i + 1 < len(sql_content) and sql_content[i + 1] == '/' and in_multiline_comment:
                in_multiline_comment = False
                i += 1

            # Handle single-line comment delimiters
            elif char == '-' and i + 1 < len(sql_content) and sql_content[i + 1] == '-' and not in_quote and not in_multiline_comment:
                in_comment = True
            elif char == '\n' and in_comment:
                in_comment = False
                current_command.append(char)

            # Drop remaining comment content so keywords inside comments
            # cannot produce false table matches downstream
            elif in_comment or in_multiline_comment:
                pass

            # Handle statement-terminating semicolons
            elif char == ';' and not in_quote:
                current_command.append(char)
                command = ''.join(current_command).strip()
                if command:  # Only add non-empty commands
                    commands.append(command)
                current_command = []
            else:
                current_command.append(char)
                
            i += 1
            
        # Add the last command if there is one
        last_command = ''.join(current_command).strip()
        if last_command:
            commands.append(last_command)
            
        return commands
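    # Illustrative example with a hypothetical two-statement script:
    #   split_commands("CREATE TABLE a (x INT); -- note; not a separator\nINSERT INTO a SELECT 1;")
    # returns ["CREATE TABLE a (x INT);", "INSERT INTO a SELECT 1;"] because the
    # semicolon inside the line comment does not end a command and the comment text is dropped.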

    # Recursively find all .sql files
    for sql_file in Path(root_path).rglob('*.sql'):
        try:
            print(f"Processing file: {sql_file}")  # Added feedback for file processing
            with open(sql_file, 'r', encoding='utf-8') as f:
                content = f.read()
                commands = split_commands(content)
                
                for command in commands:
                    # Find all CREATE TABLE statements (including OR REPLACE and TEMPORARY)
                    creates = create_pattern.findall(command)
                    for table in creates:
                        clean_table = clean_table_name(table)
                        created_tables.add(clean_table)
                        print(f"Found created table: {clean_table}")  # Added feedback
                    
                    # Find all source tables from FROM clauses
                    sources = from_pattern.findall(command)
                    for table in sources:
                        if not any(keyword in table.lower() for keyword in ['dual', 'sysdate', 'table(', 'lateral']):
                            clean_table = clean_table_name(table)
                            source_tables.add(clean_table)
                            print(f"Found source table (FROM): {clean_table}")  # Added feedback
                    
                    # Find all source tables from JOIN clauses
                    joins = join_pattern.findall(command)
                    for table in joins:
                        if not any(keyword in table.lower() for keyword in ['lateral']):
                            clean_table = clean_table_name(table)
                            source_tables.add(clean_table)
                            print(f"Found source table (JOIN): {clean_table}")  # Added feedback
                    
                    # Find tables being inserted into (including OVERWRITE)
                    inserts = insert_pattern.findall(command)
                    for table in inserts:
                        clean_table = clean_table_name(table)
                        created_tables.add(clean_table)
                        print(f"Found insert table: {clean_table}")  # Added feedback
                    
        except Exception as e:
            print(f"Error processing {sql_file}: {str(e)}")
    
    # Remove created tables from source tables if they appear in both
    source_tables = source_tables - created_tables
    
    # Write results to CSV files
    with open('source_tables.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Source Tables'])
        writer.writerows([[table] for table in sorted(source_tables)])
    
    with open('created_tables.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Created Tables'])
        writer.writerows([[table] for table in sorted(created_tables)])
    
    return source_tables, created_tables

if __name__ == "__main__":
    import sys
    
    if len(sys.argv) != 2:
        print("Usage: python script.py <root_path>")
        sys.exit(1)
    
    root_path = sys.argv[1]
    source_tables, created_tables = analyze_sql_files(root_path)
    
    print(f"\nFound {len(source_tables)} source tables and {len(created_tables)} created tables.")
    print("Results have been saved to 'source_tables.csv' and 'created_tables.csv'")