webdav-sync

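This Go program implements a WebDAV file synchronization tool. At startup, and then on a fixed interval, it downloads new or size-changed files from the WebDAV server into a local directory. It also watches that local directory, uploading created or modified files to the server and deleting removed ones there.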
 avatar
unknown
golang
16 days ago
12 kB
1
Indexable
package main

import (
	"crypto/sha256"
	"flag"
	"fmt"
	"io"
	"log"
	"net/url"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/fsnotify/fsnotify"
	"github.com/schollz/progressbar/v3"
	"github.com/studio-b12/gowebdav"
)

// Built-in configuration defaults. main uses them as the flag defaults and,
// for the integer settings, as the final fallback of getEnvIntOrDefault.
const (
	// defaultSyncInterval is the period of the server-to-local sync, in seconds.
	defaultSyncInterval = 300

	// defaultLogLevel is the logging verbosity passed to initLogger.
	defaultLogLevel = "info"

	// defaultMaxLogDirs caps how many watched directories addWatchers logs
	// individually.
	defaultMaxLogDirs = 10

	// defaultDebounce is the debounce window, in milliseconds.
	defaultDebounce = 500
)

var (
	watcher       *fsnotify.Watcher
	dav           *gowebdav.Client
	localPath     string
	serverBase    string
	serverPath    string
	user          string
	password      string
	syncMutex     sync.Mutex
	periodicCheck *time.Ticker
	fileStates    = make(map[string]fileState)
	stateMutex    sync.Mutex
	debounceTime  time.Duration
)

// fileState is the per-file bookkeeping stored in fileStates. handleWrite
// compares hash against the current content to skip re-uploading unchanged
// files. processEvent compares lastUpload against debounceTime, and
// cleanFileStates expires entries whose lastUpload is older than its maxAge.
type fileState struct {
	// lastUpload is the time.Now() value taken when the entry was stored.
	lastUpload time.Time
	// hash is the lowercase hex SHA-256 of the file content (as produced by fileHash).
	hash       string
}

// main resolves the configuration and performs an initial server-to-local
// sync. It then watches localPath recursively, starts the periodic server
// sync, and processes filesystem events until the watcher's channels close.
//
// Each setting resolves in this order: a flag passed on the command line,
// then the environment variable, then the built-in default. The following
// are fatal:
//   - missing connection settings;
//   - an unparsable URI, or one without scheme and host;
//   - a non-positive sync interval (time.NewTicker would panic on it).
func main() {
	flagDebounce := flag.Int("debounce", defaultDebounce, "Debounce time in milliseconds (env: DEBOUNCE)")
	flagSyncInterval := flag.Int("sync_interval", defaultSyncInterval, "Sync interval in seconds (env: SYNC_INTERVAL)")
	flagLogLevel := flag.String("log_level", defaultLogLevel, "Logging level (debug, info, warn, error) (env: LOG_LEVEL)")
	flagMaxLogDirs := flag.Int("max_log_dirs", defaultMaxLogDirs, "Maximum directories to log details (env: MAX_LOG_DIRS)")
	flagLocalPath := flag.String("local_path", "", "Local directory path (env: WEBDAV_LOCAL_PATH)")
	flagURI := flag.String("uri", "", "WebDAV server URI (env: WEBDAV_URI)")
	flagUser := flag.String("user", "", "WebDAV username (env: WEBDAV_USER)")
	flagPassword := flag.String("password", "", "WebDAV password (env: WEBDAV_PASSWORD)")
	flag.Parse()

	// Several flags above have non-zero defaults, so their values cannot
	// tell whether the user passed them. The getEnv*OrDefault helpers treat
	// any non-zero flag value as an override, which used to make the
	// environment variables dead. Forward a flag's value only when it was
	// set on the command line. Otherwise forward the zero value, so the
	// environment variable and then the default apply.
	setFlags := make(map[string]bool)
	flag.Visit(func(f *flag.Flag) { setFlags[f.Name] = true })
	intFlag := func(name string, value int) int {
		if setFlags[name] {
			return value
		}
		return 0
	}
	strFlag := func(name, value string) string {
		if setFlags[name] {
			return value
		}
		return ""
	}

	debounce := getEnvIntOrDefault(intFlag("debounce", *flagDebounce), "DEBOUNCE", defaultDebounce)
	debounceTime = time.Duration(debounce) * time.Millisecond

	syncInterval := getEnvIntOrDefault(intFlag("sync_interval", *flagSyncInterval), "SYNC_INTERVAL", defaultSyncInterval)
	logLevel := strings.ToLower(getEnvOrDefault(strFlag("log_level", *flagLogLevel), "LOG_LEVEL"))
	if logLevel == "" {
		logLevel = defaultLogLevel
	}
	maxLogDirs := getEnvIntOrDefault(intFlag("max_log_dirs", *flagMaxLogDirs), "MAX_LOG_DIRS", defaultMaxLogDirs)
	localPath = getEnvOrDefault(*flagLocalPath, "WEBDAV_LOCAL_PATH")
	uri := getEnvOrDefault(*flagURI, "WEBDAV_URI")
	user = getEnvOrDefault(*flagUser, "WEBDAV_USER")
	password = getEnvOrDefault(*flagPassword, "WEBDAV_PASSWORD")

	initLogger(logLevel)

	if localPath == "" || uri == "" || user == "" || password == "" {
		log.Fatal("Missing required parameters")
	}
	if syncInterval <= 0 {
		log.Fatalf("Invalid sync interval: %d (must be a positive number of seconds)", syncInterval)
	}

	parsedURI, err := url.Parse(uri)
	if err != nil {
		log.Fatalf("Invalid URI: %v", err)
	}
	if parsedURI.Scheme == "" || parsedURI.Host == "" {
		log.Fatalf("Invalid URI %q: scheme and host are required", uri)
	}
	serverBase = parsedURI.Scheme + "://" + parsedURI.Host

	// WebDAV paths are URL paths: rooted and slash-separated on every OS.
	// An empty URI path therefore maps to "/" rather than ".".
	remoteRoot := parsedURI.Path
	if !strings.HasPrefix(remoteRoot, "/") {
		remoteRoot = "/" + remoteRoot
	}
	serverPath = filepath.ToSlash(filepath.Clean(remoteRoot))

	dav = gowebdav.NewClient(serverBase, user, password)
	dav.SetTimeout(30 * time.Second)

	logInfo("Starting initial sync from server...")
	if err := syncFromServer(serverPath, localPath); err != nil {
		log.Fatalf("Initial sync failed: %v", err)
	}
	logInfo("Initial sync completed")

	watcher, err = fsnotify.NewWatcher()
	if err != nil {
		log.Fatalf("Watcher init failed: %v", err)
	}
	defer watcher.Close()

	logInfo("Initializing watchers...")
	if err := addWatchers(localPath, maxLogDirs); err != nil {
		log.Fatalf("Failed to initialize watchers: %v", err)
	}

	periodicCheck = time.NewTicker(time.Duration(syncInterval) * time.Second)
	defer periodicCheck.Stop()
	go periodicSync()

	for {
		select {
		case event, ok := <-watcher.Events:
			if !ok {
				return
			}
			logDebug("FS EVENT: %v", event)
			processEvent(event)
		case err, ok := <-watcher.Errors:
			if !ok {
				return
			}
			logError("Watcher error: %v", err)
		}
	}
}

// getEnvIntOrDefault resolves an integer setting. Resolution order:
//  1. a non-zero flagValue is returned as-is;
//  2. otherwise, envVar is returned if it parses as a base-10 integer;
//  3. otherwise (unset, empty or malformed), defaultValue is returned.
func getEnvIntOrDefault(flagValue int, envVar string, defaultValue int) int {
	if flagValue != 0 {
		return flagValue
	}
	// An empty string fails Atoi as well, so "unset" and "malformed" share
	// the fallback path.
	parsed, convErr := strconv.Atoi(os.Getenv(envVar))
	if convErr != nil {
		return defaultValue
	}
	return parsed
}

// Log severities in increasing order. A message is written only when its
// severity is at or above currentLogLevel.
const (
	levelDebug = iota
	levelInfo
	levelWarn
	levelError
)

// currentLogLevel is the minimum severity that is written. initLogger sets
// it; it is read by logAt.
var currentLogLevel = levelInfo

// initLogger configures the standard logger for level. Accepted values are
// "debug", "info", "warn" and "error", case-insensitive; anything else is
// treated as "info".
// Debug and info logging go to stdout; warn and error logging go to stderr.
// Previously every level wrote everything, and "warn" discarded all output,
// errors included.
func initLogger(level string) {
	log.SetFlags(log.LstdFlags | log.Lshortfile)
	switch strings.ToLower(level) {
	case "debug":
		currentLogLevel = levelDebug
		log.SetOutput(os.Stdout)
	case "warn":
		currentLogLevel = levelWarn
		log.SetOutput(os.Stderr)
	case "error":
		currentLogLevel = levelError
		log.SetOutput(os.Stderr)
	default:
		currentLogLevel = levelInfo
		log.SetOutput(os.Stdout)
	}
}

// logAt writes prefix followed by the formatted message, provided severity
// passes currentLogLevel. The call depth of 3 skips logAt and the
// logInfo/logDebug/logError wrapper. The Lshortfile location therefore
// names the real call site rather than this helper.
func logAt(severity int, prefix, format string, v ...interface{}) {
	if severity < currentLogLevel {
		return
	}
	// log.Printf ignores the write error too; there is nowhere to report it.
	_ = log.Output(3, prefix+fmt.Sprintf(format, v...))
}

// logInfo logs an informational message (shown at "info" and "debug").
func logInfo(format string, v ...interface{}) {
	logAt(levelInfo, "[INFO] ", format, v...)
}

// logDebug logs a diagnostic message (shown only at "debug").
func logDebug(format string, v ...interface{}) {
	logAt(levelDebug, "[DEBUG] ", format, v...)
}

// logError logs an error message (shown at every level).
func logError(format string, v ...interface{}) {
	logAt(levelError, "[ERROR] ", format, v...)
}

// getEnvOrDefault picks a string setting. A non-empty flagValue wins;
// otherwise the value of envVar is returned, which is "" when it is unset.
func getEnvOrDefault(flagValue, envVar string) string {
	if flagValue == "" {
		return os.Getenv(envVar)
	}
	return flagValue
}

// syncFromServer recursively mirrors the WebDAV directory remotePath into
// the local directory localPath:
//   - remote directories are created locally and descended into;
//   - a remote file is downloaded when needsDownload reports the local copy
//     missing or of a different size;
//   - local entries that are absent on the server are left untouched.
//
// Each call shows a progress bar over that directory's entries. The first
// error aborts the sync and is returned.
func syncFromServer(remotePath, localPath string) error {
	items, err := dav.ReadDir(remotePath)
	if err != nil {
		return err
	}

	logInfo("Syncing %d items...", len(items))
	bar := progressbar.Default(int64(len(items)))

	for _, item := range items {
		name := item.Name()
		// Entry names come from the server. Reject any that would resolve
		// outside localPath once joined: ".", "..", or an embedded separator.
		if name == "" || name == "." || name == ".." ||
			strings.ContainsAny(name, "/"+string(filepath.Separator)) {
			logError("Skipping unsafe remote entry %q in %s", name, remotePath)
			_ = bar.Add(1)
			continue
		}

		// Remote paths are URL paths and must stay slash-separated on every OS.
		remoteFull := filepath.ToSlash(filepath.Join(remotePath, name))
		localFull := filepath.Join(localPath, name)

		if item.IsDir() {
			if err := os.MkdirAll(localFull, 0755); err != nil {
				return err
			}
			if err := syncFromServer(remoteFull, localFull); err != nil {
				return err
			}
		} else if needsDownload(localFull, item.Size()) {
			if err := downloadFile(remoteFull, localFull); err != nil {
				return err
			}
		}
		_ = bar.Add(1)
	}
	return nil
}

// needsDownload reports whether the file at localPath has to be fetched
// again. It does when it cannot be stat'ed (including when it does not
// exist) or when its size differs from remoteSize.
func needsDownload(localPath string, remoteSize int64) bool {
	info, statErr := os.Stat(localPath)
	if statErr != nil {
		return true
	}
	return info.Size() != remoteSize
}

// downloadFile streams the WebDAV file remotePath into localPath, creating
// or truncating the local file.
//
// On success it records the content's SHA-256 in fileStates under
// localPath. The filesystem events produced by this write then reach
// handleWrite with a matching hash, and the file is not uploaded straight
// back to the server. The write's Close error is checked, because for a
// written file a failed Close can mean the data never reached disk.
func downloadFile(remotePath, localPath string) error {
	reader, err := dav.ReadStream(remotePath)
	if err != nil {
		return err
	}
	defer reader.Close()

	file, err := os.Create(localPath)
	if err != nil {
		return err
	}

	// Hash while writing so the recorded digest matches exactly what was stored.
	hasher := sha256.New()
	_, copyErr := io.Copy(io.MultiWriter(file, hasher), reader)
	closeErr := file.Close()
	if copyErr != nil {
		return copyErr
	}
	if closeErr != nil {
		return closeErr
	}

	stateMutex.Lock()
	fileStates[localPath] = fileState{
		lastUpload: time.Now(),
		hash:       fmt.Sprintf("%x", hasher.Sum(nil)),
	}
	stateMutex.Unlock()
	return nil
}

// addWatchers registers root and every directory beneath it with the
// package-level watcher, in a single walk.
//
// At debug level, the first maxLogDirs directories are logged by name,
// followed by one summary line counting the rest. A negative maxLogDirs is
// treated as 0. The original double walk logged directory maxLogDirs+1 and
// then counted it again in "...and N more".
//
// Any walk or watcher.Add error stops registration and is returned.
func addWatchers(root string, maxLogDirs int) error {
	limit := maxLogDirs
	if limit < 0 {
		limit = 0
	}

	watched := 0
	err := filepath.Walk(root, func(path string, fi os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if !fi.IsDir() {
			return nil
		}
		if err := watcher.Add(path); err != nil {
			return fmt.Errorf("failed to add watcher for %s: %w", path, err)
		}
		watched++
		if watched <= limit {
			logDebug("Watching directory: %s", path)
		}
		return nil
	})
	if err != nil {
		return err
	}

	if watched > limit {
		logDebug("...and %d more directories", watched-limit)
	}
	return nil
}



// processEvent handles one filesystem event while holding syncMutex, so it
// never overlaps a periodic server sync or another event. Steps:
//  1. Wait debounceTime/2, then drop the event if shouldIgnoreEvent says so.
//  2. Prune fileStates entries older than 24h.
//  3. Skip the event if it is not a removal and the path's fileStates entry
//     is younger than debounceTime. Remove and Rename are never debounced:
//     dropping them would leave a deleted file on the server.
//  4. Attempt the event up to maxEventAttempts times, pausing 1s, 2s, ...
//     between attempts. If every attempt fails, the final error is logged.
//
// NOTE(review): the debounce is leading-edge. A write landing within
// debounceTime of the previous upload is skipped, and its content is not
// uploaded until another event arrives for that file.
func processEvent(event fsnotify.Event) {
	const maxEventAttempts = 3

	syncMutex.Lock()
	defer syncMutex.Unlock()

	time.Sleep(debounceTime / 2)

	if shouldIgnoreEvent(event) {
		return
	}

	// Synchronous on purpose: it only holds stateMutex briefly, and a
	// goroutine per event buys nothing.
	cleanFileStates(24 * time.Hour)

	removal := event.Op&(fsnotify.Remove|fsnotify.Rename) != 0
	if !removal {
		stateMutex.Lock()
		state, exists := fileStates[event.Name]
		stateMutex.Unlock()

		if exists && time.Since(state.lastUpload) < debounceTime {
			logDebug("Debounce triggered for: %s", event.Name)
			return
		}
	}

	var err error
	for attempt := 1; attempt <= maxEventAttempts; attempt++ {
		if err = tryProcessEvent(event); err == nil {
			return
		}
		// No point sleeping after the last attempt.
		if attempt < maxEventAttempts {
			time.Sleep(time.Duration(attempt) * time.Second)
		}
	}
	logError("Giving up on %s after %d attempts: %v", event.Name, maxEventAttempts, err)
}

// handleWrite uploads the local file localPath to webdavPath. Nothing is
// uploaded when localPath is a directory or when its SHA-256 equals the
// hash recorded in fileStates.
//
// After a successful upload, the hash and the current time are stored in
// fileStates. A file that vanished before it could be read is skipped
// without error, because retrying cannot succeed. stateMutex is held only
// around map access, never across the network upload.
func handleWrite(localPath, webdavPath string) error {
	if isDir(localPath) {
		logDebug("Skipping directory write: %s", localPath)
		return nil
	}

	currentHash, err := fileHash(localPath)
	if err != nil {
		if os.IsNotExist(err) {
			logDebug("File vanished before upload: %s", localPath)
			return nil
		}
		logError("Hash error: %v", err)
		return err
	}

	stateMutex.Lock()
	state, exists := fileStates[localPath]
	stateMutex.Unlock()

	if exists && state.hash == currentHash {
		logDebug("File unchanged: %s", localPath)
		return nil
	}

	logDebug("Starting upload: %s", localPath)
	file, err := os.Open(localPath)
	if err != nil {
		if os.IsNotExist(err) {
			logDebug("File vanished before upload: %s", localPath)
			return nil
		}
		logError("Open error: %v", err)
		return err
	}
	defer file.Close()

	if err := dav.WriteStream(webdavPath, file, 0644); err != nil {
		logError("Upload failed: %v", err)
		return err
	}

	stateMutex.Lock()
	fileStates[localPath] = fileState{
		lastUpload: time.Now(),
		hash:       currentHash,
	}
	stateMutex.Unlock()

	logInfo("Successfully uploaded: %s", webdavPath)
	return nil
}

// fileHash returns the lowercase hexadecimal SHA-256 digest of the file at
// path. The content is streamed, not read into memory. On an open or read
// failure it returns "" together with that error.
func fileHash(path string) (string, error) {
	f, openErr := os.Open(path)
	if openErr != nil {
		return "", openErr
	}
	defer f.Close()

	digest := sha256.New()
	_, copyErr := io.Copy(digest, f)
	if copyErr != nil {
		return "", copyErr
	}

	sum := digest.Sum(nil)
	return fmt.Sprintf("%x", sum), nil
}

// tryProcessEvent maps event.Name to its WebDAV path and dispatches on the
// event type. The WebDAV path is serverPath plus the path relative to
// localPath, always slash-separated. Dispatch:
//   - Write goes to handleWrite;
//   - Create goes to handleCreate;
//   - Remove or Rename goes to handleRemove;
//   - other types are ignored.
//
// It returns the handler's error. Events on the sync root itself (e.g. the
// root was deleted or moved) are refused: mapping them would target
// serverPath and delete the whole remote tree.
func tryProcessEvent(event fsnotify.Event) error {
	relPath, err := filepath.Rel(localPath, event.Name)
	if err != nil {
		logError("Path error: %v", err)
		return err
	}
	if relPath == "." {
		logError("Ignoring %s event on the sync root %s", event.Op, event.Name)
		return nil
	}

	// Join first, then convert. filepath.Join would turn forward slashes
	// back into the OS separator, but WebDAV paths must use "/".
	webdavPath := filepath.ToSlash(filepath.Join(serverPath, relPath))

	logDebug("Processing: %s %s", event.Op, webdavPath)

	switch {
	case event.Op&fsnotify.Write == fsnotify.Write:
		return handleWrite(event.Name, webdavPath)
	case event.Op&fsnotify.Create == fsnotify.Create:
		return handleCreate(event.Name, webdavPath)
	case event.Op&fsnotify.Remove == fsnotify.Remove,
		event.Op&fsnotify.Rename == fsnotify.Rename:
		return handleRemove(webdavPath)
	}
	return nil
}

// shouldIgnoreEvent reports whether event should be dropped before any sync
// work is done.
//   - Creation of a directory is always kept, so the new directory gets
//     watched and uploaded.
//   - Any other event on an existing directory is dropped.
//   - Events carrying the Chmod bit are dropped.
//   - Editor backups ("~" suffix), hidden files (leading ".") and ".tmp"
//     files are dropped.
//
// The name filters examine only the final path element. event.Name is a
// full path, so testing it directly never matched hidden files under an
// absolute root. Under a relative root such as "../data", every name starts
// with ".", so it rejected every event.
func shouldIgnoreEvent(event fsnotify.Event) bool {
	if event.Op == fsnotify.Create && isDir(event.Name) {
		return false
	}
	if isDir(event.Name) {
		return true
	}
	if event.Op&fsnotify.Chmod != 0 {
		return true
	}

	base := filepath.Base(event.Name)
	return strings.HasSuffix(base, "~") ||
		strings.HasPrefix(base, ".") ||
		strings.HasSuffix(base, ".tmp")
}

// handleCreate mirrors a newly created local path. A non-directory is
// passed to handleWrite. A directory gets three things:
//   - webdavPath is created on the server;
//   - watchers are added for it and for every subdirectory already inside it;
//   - its current contents are uploaded in the background via syncToServer.
//
// That background goroutine takes syncMutex first. processEvent holds
// syncMutex while calling this, so the upload starts only after the current
// event has been handled.
func handleCreate(localPath, webdavPath string) error {
	if !isDir(localPath) {
		return handleWrite(localPath, webdavPath)
	}

	if err := dav.MkdirAll(webdavPath, 0755); err != nil {
		logError("Mkdir error: %v", err)
		return err
	}

	// Best effort: one unreadable or vanished entry must not stop the other
	// subdirectories from being watched. Walk passes a nil info when an
	// entry cannot be stat'ed, so info is checked before use. Previously
	// this panicked.
	if err := filepath.Walk(localPath, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			logError("Failed to watch %s: %v", path, err)
		}
		if info != nil && info.IsDir() {
			if addErr := watcher.Add(path); addErr != nil {
				logError("Failed to watch %s: %v", path, addErr)
			}
		}
		return nil
	}); err != nil {
		logError("Failed to add watchers: %v", err)
	}

	logInfo("Created directory: %s", webdavPath)

	go func() {
		syncMutex.Lock()
		defer syncMutex.Unlock()
		if err := syncToServer(localPath, webdavPath); err != nil {
			logError("Sync failed: %v", err)
		}
	}()
	return nil
}

// syncToServer uploads the local tree rooted at localDir to webdavDir. Each
// directory is created on the server with MkdirAll, and each file is sent
// through uploadFile. Remote paths are always slash-separated, whatever the
// OS. The first walk, mkdir or upload error stops the walk and is returned.
func syncToServer(localDir, webdavDir string) error {
	return filepath.Walk(localDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}

		relPath, err := filepath.Rel(localDir, path)
		if err != nil {
			return err
		}
		// Join first, then convert: WebDAV paths must use "/" on every OS.
		remotePath := filepath.ToSlash(filepath.Join(webdavDir, relPath))

		if info.IsDir() {
			return dav.MkdirAll(remotePath, 0755)
		}
		return uploadFile(path, remotePath)
	})
}

// uploadFile streams the local file at localPath to remotePath on the WebDAV
// server, passing 0644 as the file mode. It returns the open or upload error,
// if any.
func uploadFile(localPath, remotePath string) error {
	src, err := os.Open(localPath)
	if err == nil {
		defer src.Close()
		err = dav.WriteStream(remotePath, src, 0644)
	}
	return err
}



// handleRemove deletes webdavPath (a file or a whole directory) from the
// server. It then forgets the fileStates entries for the corresponding local
// path and everything beneath it. Without that, a file recreated later with
// identical content would be skipped as "unchanged" and never re-uploaded.
//
// fileStates is keyed by local path (handleWrite stores entries under the
// event's local name). The local path is therefore derived from webdavPath
// via serverPath and localPath. The old code matched keys against the WebDAV
// path itself, so it never removed anything.
func handleRemove(webdavPath string) error {
	if err := dav.Remove(webdavPath); err != nil {
		logError("Delete error: %v", err)
		return err
	}

	if rel, err := filepath.Rel(serverPath, webdavPath); err != nil {
		logError("Cannot map %s to a local path: %v", webdavPath, err)
	} else {
		local := filepath.Join(localPath, rel)
		prefix := local + string(filepath.Separator)

		stateMutex.Lock()
		for k := range fileStates {
			// Match the path itself or entries strictly inside it. A bare
			// HasPrefix would also drop siblings such as "a.txt2" for "a.txt".
			if k == local || strings.HasPrefix(k, prefix) {
				delete(fileStates, k)
			}
		}
		stateMutex.Unlock()
	}

	logInfo("Deleted: %s", webdavPath)
	return nil
}



// cleanFileStates removes every fileStates entry whose lastUpload is more
// than maxAge in the past, so the map does not keep entries indefinitely.
func cleanFileStates(maxAge time.Duration) {
	stateMutex.Lock()
	defer stateMutex.Unlock()

	// "older than maxAge" is the same as "before now-maxAge"; the cutoff is
	// taken once, after the lock is held.
	cutoff := time.Now().Add(-maxAge)
	for p, st := range fileStates {
		if st.lastUpload.Before(cutoff) {
			delete(fileStates, p)
		}
	}
}


// isDir reports whether path exists and is a directory. Symlinks are
// followed, as with os.Stat. Any Stat error, including non-existence,
// yields false.
func isDir(path string) bool {
	if info, statErr := os.Stat(path); statErr == nil {
		return info.IsDir()
	}
	return false
}

// periodicSync reruns the server-to-local sync from serverPath into
// localPath on every tick of periodicCheck. syncMutex is held for each run,
// so a sync never interleaves with event handling.
func periodicSync() {
	for range periodicCheck.C {
		func() {
			syncMutex.Lock()
			defer syncMutex.Unlock()

			logInfo("Starting periodic server sync...")
			err := syncFromServer(serverPath, localPath)
			if err != nil {
				logError("Periodic sync failed: %v", err)
				return
			}
			logInfo("Periodic sync completed")
		}()
	}
}
Leave a Comment