Untitled

mail@pastecode.io avatar
unknown
plain_text
7 months ago
8.7 kB
4
Indexable
Never
package config

import (
	"fmt"
	"github.com/go-playground/validator/v10"
	"github.com/sauron-platform/network-monitor/pkg/defaults"
	"github.com/spf13/viper"
	"net"
	"os"
	"reflect"
	"regexp"
	ctrl "sigs.k8s.io/controller-runtime"
	"strconv"
	"strings"
	"time"
)

// log is the logger used by this package. It is derived from the controller-runtime root logger and carries the
// name "config".
var (
	log = ctrl.Log.WithName("config")
)

// Configuration option keys, as they are looked up in viper. A dot separates a nested section from the option it
// contains (e.g. "otelConfig.exportEndpoint"). Every key declared here should also be listed in the options array,
// otherwise it is not bound to an environment variable by bindEnvs.
const (
	optionHostname                = "hostname"
	optionDeviceType              = "deviceType"
	optionGrpcUntrusted           = "grpcUntrusted"
	optionGrpcServerPort          = "grpcServerPort"
	optionLinkServicePort         = "linkServicePort"
	optionE2EServicePort          = "e2eServicePort"
	optionNeighServicePort        = "neighServicePort"
	optionOtelExportEndpoint      = "otelConfig.exportEndpoint"
	optionOtelExportInterval      = "otelConfig.exportInterval"
	optionLinkTestDscpValues      = "linkTestConfig.dscpValues"
	optionLinkTestPayloadLen      = "linkTestConfig.payloadLen"
	optionLinkTestMeasureInterval = "linkTestConfig.measureInterval"
)

// options enumerates every configuration option key. bindEnvs iterates over it to bind each key to its environment
// variable, so a key missing from this list cannot be provided through the environment.
var (
	options = [...]string{
		optionHostname,
		optionDeviceType,
		optionGrpcUntrusted,
		optionGrpcServerPort,
		optionLinkServicePort,
		optionE2EServicePort,
		optionNeighServicePort,
		optionOtelExportEndpoint,
		optionOtelExportInterval,
		optionLinkTestDscpValues,
		optionLinkTestPayloadLen,
		optionLinkTestMeasureInterval,
	}
)

// Notice: mapstructure tag names must match the option key definitions, and every struct tag must use the exact
// `key:"value"` syntax. A malformed key (for instance `validate=:"..."`) is silently ignored by the reflect package,
// which disables the corresponding validation without any compile-time or run-time error.

// OtelConfig groups the options found under the "otelConfig" section.
type OtelConfig struct {
	// ExportEndpoint is either left empty (its zero value) or must satisfy the custom "endpoint_port" validation.
	ExportEndpoint string `mapstructure:"exportEndpoint" validate:"isdefault|endpoint_port"`
	// ExportInterval must be greater than or equal to 5 minutes.
	ExportInterval time.Duration `mapstructure:"exportInterval" validate:"gte=5m"`
}

// LinkTestConfig groups the options found under the "linkTestConfig" section.
type LinkTestConfig struct {
	// DscpValues is mandatory; every element must be in the range [0, 63].
	DscpValues []uint8 `mapstructure:"dscpValues" validate:"required,dive,gte=0,lte=63"`
	// PayloadLen must not exceed 8956.
	PayloadLen uint16 `mapstructure:"payloadLen" validate:"gte=0,lte=8956"`
	// MeasureInterval must be greater than or equal to 5 seconds.
	MeasureInterval time.Duration `mapstructure:"measureInterval" validate:"gte=5s"`
}

// Configuration is the root of the configuration tree. Its fields are filled by Populate and checked against the
// constraints expressed in the validate tags.
type Configuration struct {
	// Hostname is mandatory (Populate defaults it to the machine hostname when the option is not set).
	Hostname string `mapstructure:"hostname" validate:"required"`
	// DeviceType must be one of "vCU" or "NGDU".
	DeviceType string `mapstructure:"deviceType" validate:"oneof=vCU NGDU"`
	// GrpcUntrusted is excluded from validation.
	GrpcUntrusted bool `mapstructure:"grpcUntrusted" validate:"-"`
	// The service ports below must all be in the range [1, 65535].
	GrpcServerPort   uint16         `mapstructure:"grpcServerPort" validate:"gte=1,lte=65535"`
	LinkServicePort  uint16         `mapstructure:"linkServicePort" validate:"gte=1,lte=65535"`
	E2EServicePort   uint16         `mapstructure:"e2eServicePort" validate:"gte=1,lte=65535"`
	NeighServicePort uint16         `mapstructure:"neighServicePort" validate:"gte=1,lte=65535"`
	OtelConfig       OtelConfig     `mapstructure:"otelConfig"`
	LinkTestConfig   LinkTestConfig `mapstructure:"linkTestConfig"`
}

// Config is the package-level Configuration instance. It starts out zero-valued; its fields only become meaningful
// after a successful call to its Populate method.
var Config = &Configuration{}

// Populate fills the receiver with the values found in the configuration file and in the bound environment
// variables, then validates the result against the constraints declared in the struct tags.
//
// A missing configuration file is not an error. Any other failure (viper setup, unreadable file, hostname lookup,
// decoding, validation) is returned wrapped with %w, so callers can inspect the underlying cause with errors.Is and
// errors.As (for example to extract validator.ValidationErrors).
func (c *Configuration) Populate() error {
	// Configure the global viper instance: config file location, defaults and environment bindings.
	v := viper.GetViper()
	if err := setupViper(v); err != nil {
		return fmt.Errorf("cannot setup viper: %w", err)
	}

	// Load the configuration file. Only "file not found" is tolerated; in that case the values come from the
	// environment variables and the defaults.
	if err := v.ReadInConfig(); err != nil {
		if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
			return fmt.Errorf("cannot read config file: %w", err)
		}
		log.V(1).Info("Cannot find configuration file. Fallback to environment variable")
	}

	// The hostname has no static default: fall back to the machine hostname only when the user did not provide one,
	// so that a failing lookup cannot break an otherwise complete configuration.
	if !v.IsSet(optionHostname) {
		log.V(1).Info("Hostname option not provided. Attempt to get machine hostname")
		hostname, err := os.Hostname()
		if err != nil {
			return fmt.Errorf("cannot retrieve machine hostname: %w", err)
		}
		log.V(1).Info("Hostname defaulted to machine hostname", "hostname", hostname)
		v.SetDefault(optionHostname, hostname)
	}

	// Decode the collected values into the receiver.
	if err := v.Unmarshal(c); err != nil {
		return fmt.Errorf("cannot parse retrieved configuration: %w", err)
	}

	// Custom validations must be registered before the struct is validated.
	validate := validator.New()
	if err := registerValidations(validate); err != nil {
		return err
	}
	if err := validate.Struct(c); err != nil {
		return fmt.Errorf("failed validation: %w", err)
	}
	return nil
}

// Names of the custom validation tags usable in `validate` struct tags.
const (
	// validationTagEndpointPort is the tag under which validateEndpointPort is registered.
	validationTagEndpointPort = "endpoint_port"
)

// registerValidations registers the package's custom validation functions on the given validator. At the moment the
// only one is the "endpoint_port" tag, implemented by validateEndpointPort. The registration error, if any, is
// returned wrapped so the underlying cause stays inspectable.
func registerValidations(validate *validator.Validate) error {
	if err := validate.RegisterValidation(validationTagEndpointPort, validateEndpointPort); err != nil {
		return fmt.Errorf("cannot register validation: %w", err)
	}
	return nil
}

// Regex copied from https://github.com/go-playground/validator repo.
// It accepts one or more dot-separated labels; each label starts with an ASCII letter or digit, may continue with
// letters, digits or hyphens, and is at most 63 characters long.
const hostnameRegexStringRFC1123 = `^([a-zA-Z0-9]{1}[a-zA-Z0-9-]{0,62}){1}(\.[a-zA-Z0-9]{1}[a-zA-Z0-9-]{0,62})*?$`

// hostnameRegexRFC1123 is the compiled form of hostnameRegexStringRFC1123, built once at package initialization.
var hostnameRegexRFC1123 = regexp.MustCompile(hostnameRegexStringRFC1123)

// validateEndpointPort is the validator callback behind the "endpoint_port" tag. It accepts only string fields whose
// value has the form [host]:port. The host part is optional; when present it must be an RFC1123 hostname or an IP
// address (IPv4, or IPv6 enclosed in square brackets). Non-string fields are always rejected.
func validateEndpointPort(fl validator.FieldLevel) bool {
	if value := fl.Field(); value.Kind() == reflect.String {
		return isEndpointPort(value.String())
	}
	return false
}

// isEndpointPort reports whether val has the form [host]:port.
//
// The port must be a plain decimal number in the range [1, 65535]; a sign prefix such as "+80" is rejected. The host
// may be empty (e.g. ":4317"). When present, it must either match the RFC1123 hostname pattern or parse as an IP
// address; IPv6 addresses must be written in square brackets, as required by net.SplitHostPort.
func isEndpointPort(val string) bool {
	host, port, err := net.SplitHostPort(val)
	if err != nil {
		return false
	}

	// ParseUint with a 16-bit size rejects signs, non-digits and anything above 65535; port 0 is excluded explicitly.
	portNum, err := strconv.ParseUint(port, 10, 16)
	if err != nil || portNum == 0 {
		return false
	}

	// An empty host is allowed.
	if host == "" {
		return true
	}

	// Otherwise the host must be a DNS name or an IP address.
	return hostnameRegexRFC1123.MatchString(host) || net.ParseIP(host) != nil
}

// Default viper configuration parameters. setupViper passes them to viper to locate the configuration file: its
// name (without extension), its format and the directory in which it is searched.
const (
	viperConfigName = "config"
	viperConfigType = "yaml"
	viperConfigPath = "/config/"
)

// setupViper prepares the given viper instance for loading the configuration: it sets the configuration file name,
// type and search path, installs the default value of each option and binds each option to its environment
// variable. A binding failure is returned wrapped with %w so the underlying cause stays inspectable.
func setupViper(v *viper.Viper) error {
	// Describe the configuration file viper has to look for.
	v.SetConfigName(viperConfigName)
	v.SetConfigType(viperConfigType)
	v.AddConfigPath(viperConfigPath)

	// Install the default value of each option.
	setOptionsDefaults(v)

	// Bind each option to the corresponding environment variable name.
	if err := bindEnvs(v); err != nil {
		return fmt.Errorf("error during environment variable binding: %w", err)
	}

	// Live reload of the configuration file is not enabled yet.
	//v.OnConfigChange(func(e fsnotify.Event) {
	//	// TODO:
	//	fmt.Println("Config file changed:", e.Name)
	//})
	//v.WatchConfig()

	return nil
}

// setOptionsDefaults installs on v the default value of every option that has one.
//
// Two options are intentionally left without a default:
//   - hostname: the machine hostname is looked up lazily, only when the option is not provided, because the lookup
//     can fail and should not be performed in advance;
//   - otelConfig.exportEndpoint: no default exists for it.
func setOptionsDefaults(v *viper.Viper) {
	optionDefaults := []struct {
		option string
		value  interface{}
	}{
		{optionDeviceType, defaults.DeviceType},
		{optionGrpcUntrusted, defaults.GrpcUntrusted},
		{optionGrpcServerPort, defaults.GrpcServerPort},
		{optionLinkServicePort, defaults.LinkServicePort},
		{optionE2EServicePort, defaults.E2EServicePort},
		{optionNeighServicePort, defaults.NeighServicePort},
		{optionOtelExportInterval, defaults.OtelExportInterval},
		{optionLinkTestDscpValues, defaults.LinkTestDscpValues()},
		{optionLinkTestPayloadLen, defaults.LinkTestPayloadLen},
		{optionLinkTestMeasureInterval, defaults.LinkTestMeasureInterval},
	}

	for _, d := range optionDefaults {
		v.SetDefault(d.option, d.value)
	}
}

// bindEnvs associates every entry of options with an environment variable on v. The variable name is the option
// key, assumed to be in mixed case, converted to upper snake case by mixedCaseToUpperSnakeCase. The first binding
// error stops the iteration and is returned unchanged.
func bindEnvs(v *viper.Viper) error {
	for i := range options {
		key := options[i]
		envName := mixedCaseToUpperSnakeCase(key)
		if err := v.BindEnv(key, envName); err != nil {
			return err
		}
	}
	return nil
}

// Patterns used to locate word boundaries in a mixed case identifier.
var (
	// matchFirstCap finds any character followed by a capitalized word (an upper case letter and at least one
	// lower case letter).
	matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
	// matchAllCap finds a lower case letter or digit directly followed by an upper case letter.
	matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")
)

// mixedCaseToUpperSnakeCase turns a mixed case string into its upper snake case form: an underscore is inserted at
// each word boundary, the result is upper-cased and every dot is replaced by an underscore. For example,
// "otelConfig.exportEndpoint" becomes "OTEL_CONFIG_EXPORT_ENDPOINT".
func mixedCaseToUpperSnakeCase(s string) string {
	const withUnderscore = "${1}_${2}"

	separated := matchAllCap.ReplaceAllString(
		matchFirstCap.ReplaceAllString(s, withUnderscore),
		withUnderscore,
	)
	return strings.ReplaceAll(strings.ToUpper(separated), ".", "_")
}