Untitled

mail@pastecode.io avatar
unknown
python
2 years ago
2.0 kB
9
Indexable
Never
class LloydMax:
    """Lloyd-Max scalar quantizer designed from samples of a distribution.

    The design alternates two steps until the reconstruction ("restitution")
    levels stop moving by more than a threshold:

    1. Place the decision thresholds ("partitions") at the midpoints of
       adjacent levels.
    2. Move each level to the mean of the samples that fall in its cell.

    Attributes:
        restitution_lvls: ascending 1-D array of ``num_levels`` levels.
        bounds: ``(low, high)`` range used to draw random initial levels.
        num_levels: number of quantization levels.
        partitions: ``num_levels - 1`` decision thresholds, or ``None`` until
            they are first computed.
        distribution: callable invoked as ``distribution(size=n)``; it is
            expected to return ``n`` samples as a 1-D array-like
            (e.g. ``np.random.normal``).
    """

    __slots__ = 'restitution_lvls', 'bounds', 'num_levels', 'partitions', 'distribution'

    def __init__(self, num_levels, bounds, distribution, restitution_lvl=None):
        """Create a quantizer.

        Args:
            num_levels: number of reconstruction levels.
            bounds: ``(low, high)``; initial levels are drawn uniformly from
                this range when ``restitution_lvl`` is not given.
            distribution: sampler called as ``distribution(size=n)``.
            restitution_lvl: optional initial levels (any order).

        Raises:
            ValueError: if ``restitution_lvl`` does not have ``num_levels``
                entries.
        """
        self.num_levels = num_levels
        self.bounds = bounds
        self.distribution = distribution
        self.partitions = None

        if restitution_lvl is None:
            self.restitution_lvls = np.sort(np.random.uniform(bounds[0], bounds[1], num_levels))
        else:
            if len(restitution_lvl) != num_levels:
                raise ValueError('Non coherent input')
            # Partitions are midpoints of *adjacent* levels, which is only
            # meaningful when the levels are in ascending order.
            self.restitution_lvls = np.sort(np.asarray(restitution_lvl, dtype=float))

    def _make_partitions(self):
        """Set ``self.partitions`` to the midpoints of adjacent levels."""
        levels = np.asarray(self.restitution_lvls, dtype=float)
        # Kept in float64 (previously float32) so thresholds are as precise
        # as the levels they are derived from.
        self.partitions = 0.5 * (levels[:-1] + levels[1:])

    def _get_restitution_lvls(self, realizations=10000, samples=None):
        """Return updated levels: the mean of the samples in each cell.

        Args:
            realizations: number of samples to draw when ``samples`` is None.
            samples: optional pre-drawn samples to reuse.

        Returns:
            A float ndarray of ``num_levels`` levels. A cell that receives no
            samples keeps its previous level instead of becoming NaN.
        """
        if samples is None:
            samples = self.distribution(size=realizations)
        samples = np.asarray(samples, dtype=float)
        if self.partitions is None:
            self._make_partitions()

        new_lvls = np.array(self.restitution_lvls, dtype=float)  # copy
        for i, (low, high) in enumerate(self._get_intervals(self.partitions)):
            # Half-open cells [low, high): a sample exactly on a threshold is
            # assigned to one cell rather than dropped from both.
            in_cell = samples[(samples >= low) & (samples < high)]
            # np.mean of an empty array is NaN, and NaN would make the
            # convergence test in compute() False, silently "converging".
            if in_cell.size:
                new_lvls[i] = in_cell.mean()
        return new_lvls

    def compute(self, norm, threshold, realizations=10000, max_iter=None):
        """Run Lloyd-Max iterations until the levels converge.

        Args:
            norm: callable ``norm(old_levels, new_levels)`` returning a scalar
                distance between successive level vectors.
            threshold: stop once that distance is ``<= threshold``.
            realizations: number of samples drawn (once) from
                ``self.distribution``.
            max_iter: optional cap on the number of iterations; ``None`` means
                unlimited.
        """
        # Draw the samples once. Redrawing every iteration adds Monte Carlo
        # noise that can keep `difference` above a small threshold forever.
        samples = self.distribution(size=realizations)
        difference = float('inf')
        iteration = 0
        while difference > threshold:
            if max_iter is not None and iteration >= max_iter:
                break
            self._make_partitions()
            new_rls = self._get_restitution_lvls(samples=samples)
            difference = norm(self.restitution_lvls, new_rls)
            self.restitution_lvls = new_rls
            iteration += 1
        # Keep the thresholds consistent with the final levels.
        self._make_partitions()

    @staticmethod
    def _get_intervals(partitions):
        """Return the ``(low, high)`` cell bounds delimited by ``partitions``.

        The first cell starts at -inf and the last ends at +inf, so
        ``len(partitions) + 1`` intervals are returned.
        """
        edges = np.concatenate(([float('-inf')], partitions, [float('inf')]))
        return list(zip(edges[:-1], edges[1:]))