/*
Paste-site metadata (not part of the program):
Untitled

mail@pastecode.io avatar
unknown
plain_text
a month ago
17 kB
2
Indexable
Never
*/
#include "judge.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <set>
#include <vector>
#include <memory>
#include <type_traits>
#include <optional>
using namespace std;

// Debug-only structural checks on a segment-tree node pointer: it is non-null,
// its range is non-empty and it starts at index 1 or later.
// Wrapped in do { } while (0) so the macro behaves as a single statement
// (safe inside an unbraced if/else); the argument is parenthesized.
#define NODE_ASSERT(v)            \
  do {                            \
    assert((v) != nullptr);       \
    assert((v)->l <= (v)->r);     \
    assert((v)->l >= 1);          \
  } while (0)

// Compile-time guard for tree operations that only make sense when the
// `value_indexed` flag is true. The stringified argument is embedded in the
// diagnostic so the user sees which flag was false.
#define UPPER_LOWER_ASSERT(x)                                        \
  static_assert((x),                                                 \
                "Cannot use this method in case when " #x            \
                " template parameter is equal to false. This type "  \
                "of operation works only for value indexed trees.")

namespace peki
{

/// Always-false value that depends on its template arguments. A static_assert
/// on it fires only when the surrounding template is actually instantiated;
/// a plain `static_assert(false, ...)` inside a template is ill-formed before
/// C++23 and is rejected eagerly by common compilers.
template <typename...>
inline constexpr bool dependent_false_v = false;

/// Maps a stored value V to the tree index T it is stored at. Only needed by
/// value-indexed trees (SegmentTree::lowerEqual / SegmentTree::upper).
/// Specialize it for any (T, V) combination not handled below.
template <typename T, typename V>
struct SegmentTreeValueToIndex {
  static_assert(dependent_false_v<T, V>, "Specialize struct for types T, V you used to instantiate tree.");
};

/// The value is its own index.
template <typename T>
struct SegmentTreeValueToIndex<T, T> {
  T operator()(const T& value) const {
    return value;
  }
};

/// A pair is indexed by its first component.
template <typename T, typename V>
struct SegmentTreeValueToIndex<T, std::pair<T, V>>
{
  T operator()(const std::pair<T, V>& value) const {
    return value.first;
  }
};

/// Sparse segment tree over the indices [1, size], where size is N rounded up
/// to a power of two.
///
/// Each index holds a multiset of V values ordered by Cmp. Every node caches the
/// number of values in its range together with the smallest and largest of them
/// (by Cmp). Nodes are allocated on insert and detached again as soon as they
/// become empty; only the root is never freed.
///
/// With value_indexed == true, values are located through
/// SegmentTreeValueToIndex<T, V>, which enables lowerEqual() and upper().
///
/// @tparam T             integral index type
/// @tparam V             stored value type
/// @tparam N             number of indices the tree must cover (N > 0)
/// @tparam value_indexed whether lowerEqual()/upper() are available
/// @tparam Cmp           default-constructible strict ordering on V with
///                       `bool operator()(const V&, const V&) const`
template <typename T, typename V, T N, bool value_indexed = false, typename Cmp = std::less<V>>
class SegmentTree {

  // Detects a comparator that can be called on a const object with two const V&
  // and returns bool (userCmp below is a const object).
  template <typename VerifyT, typename = void>
  struct CmpVerify : std::false_type
  {};

  template <typename VerifyT>
  struct CmpVerify<VerifyT, std::void_t<std::enable_if_t<
    std::is_same_v<bool, decltype(std::declval<const VerifyT&>()(std::declval<const V&>(), std::declval<const V&>()))>>>> : std::true_type
  {};

  static_assert(std::is_integral_v<T>, "Tree is indexed by positive integral numbers.");
  static_assert(N > 0, "Tree should have at least one element.");
  static_assert(CmpVerify<Cmp>::value, "\nComparator is not proper function object type.\n"
                                       "Need to have overloaded operator () with signature:\n"
                                       "bool operator()(const V&, const V&).");
  static_assert(std::is_default_constructible_v<Cmp>, "Comparator should be default constructible.");

  // Optional reference to a value stored inside the tree (empty = no such value).
  using OptRef = std::optional<std::reference_wrapper<const V>>;

  const inline static Cmp userCmp = Cmp();

  // Orders (value, insertion id) entries by value under Cmp, then by id, so that
  // equivalent values can coexist in one std::set.
  struct Compare {
    bool operator()(const std::pair<V, uint32_t>& a, const std::pair<V, uint32_t>& b) const {
      const auto& [Avalue, Aindex] = a;
      const auto& [Bvalue, Bindex] = b;
      if (!userCmp(Avalue, Bvalue) && !userCmp(Bvalue, Avalue)) {
        return Aindex < Bindex;
      }
      return userCmp(Avalue, Bvalue);
    }
  };

  // Smaller of two values under Cmp (returns b on ties).
  struct Min {
    const V& operator()(const V& a, const V& b) const {
      if (userCmp(a, b)) return a;
      return b;
    }
  };

  // Larger of two values under Cmp (returns b on ties).
  struct Max {
    const V& operator()(const V& a, const V& b) const {
      if (userCmp(b, a)) return a;
      return b;
    }
  };

  // One tree node covering the index range [l, r].
  struct Node {
    Node()
      : l(0), r(0), count(0), parent(nullptr)
    {}
    Node(T l_, T r_)
      : l(l_), r(r_), count(0), parent(nullptr)
    {}

    Node(T l_, T r_, Node* parent_)
      : l(l_), r(r_), count(0), parent(parent_)
    {}

    T l;
    T r;
    T count;                       // number of values stored in [l, r]
    std::optional<V> min;          // smallest value in [l, r]; empty iff count == 0
    std::optional<V> max;          // largest value in [l, r]; empty iff count == 0
    std::set<std::pair<V, uint32_t>, Compare> values;  // only used by leaves
    Node *parent;
    std::unique_ptr<Node> left;
    std::unique_ptr<Node> right;
  };

  private:
    enum class Operation {
      INSERT,
      ERASE
    };

    // Smallest power of two that is >= n.
    static constexpr T getTreeSize(T n) {
      T size = 1;
      while (size < n) {
        size *= 2;
      }
      return size;
    }

    // Split point of a node's range; children cover [l, mid] and [mid + 1, r].
    // Written as l + (r - l) / 2 to avoid overflowing l + r.
    static T midpoint(const Node* v) {
      return static_cast<T>(v->l + (v->r - v->l) / 2);
    }

    // Debug-only structural checks on a node.
    static void assertNode([[maybe_unused]] const Node* v) {
      assert(v != nullptr);
      assert(v->l <= v->r);
      assert(v->l >= 1);
    }

    // Called once v->count is up to date. If v is empty, the root is reset in
    // place (it must always exist) and any other node is detached from its
    // parent, which destroys it. Returns true iff v was empty; in that case v
    // may no longer be valid.
    bool removeVerticeIfUnused(Node* v) {
      if (v->count > 0) return false;
      if (v == root.get()) {
        v->count = 0;
        v->min = std::nullopt;
        v->max = std::nullopt;
        v->values.clear();
        return true;
      }
      assert(v->parent);
      if (v->parent->left.get() == v) {
        v->parent->left.reset();
      } else {
        v->parent->right.reset();
      }
      return true;
    }

    // Recomputes count/min/max of an inner node from its children and detaches
    // it if it became empty.
    void updateNode(Node* v)
    {
      const Node* lc = v->left.get();
      const Node* rc = v->right.get();
      v->count = (lc ? lc->count : T{0}) + (rc ? rc->count : T{0});
      if (removeVerticeIfUnused(v)) return;  // v may have been destroyed

      // Combines two optional extremes; a missing or empty side does not participate.
      auto combine = [](const std::optional<V>* a, const std::optional<V>* b, const auto& op) -> std::optional<V> {
        const bool hasA = a != nullptr && a->has_value();
        const bool hasB = b != nullptr && b->has_value();
        if (hasA && hasB) return op(**a, **b);
        if (hasA) return *a;
        if (hasB) return *b;
        return std::nullopt;
      };
      v->min = combine(lc ? &lc->min : nullptr, rc ? &rc->min : nullptr, minOp);
      v->max = combine(lc ? &lc->max : nullptr, rc ? &rc->max : nullptr, maxOp);
    }

    // Applies an insert/erase to a leaf's value multiset and refreshes its cache.
    void updateLeaf(Node* v, const V& element, Operation op)
    {
      // Refreshes cached extremes from the (non-empty) ordered value set.
      auto refreshMinMax = [v]() {
        v->min = v->values.begin()->first;
        v->max = v->values.rbegin()->first;
      };

      switch (op)
      {
        case Operation::INSERT: {
          v->count += 1;
          v->values.emplace(element, next_element_id++);
          refreshMinMax();
          break;
        }
        case Operation::ERASE: {
          // First stored entry not ordered before `element` (lowest id among equivalents).
          auto it = v->values.lower_bound(std::make_pair(element, static_cast<uint32_t>(0)));
          // Equivalence is decided by Cmp, the same ordering the set uses;
          // V is not required to provide operator<.
          if (it != v->values.end() && !userCmp(it->first, element) && !userCmp(element, it->first)) {
            v->values.erase(it);
            v->count -= 1;
          }
          if (!removeVerticeIfUnused(v)) {
            refreshMinMax();
          }
          break;
        }
        default:
          assert(!"Invalid operation!");
      }
    }

    // Walks from v to the leaf for `index`, creating nodes on insert, then
    // recomputes every node on the way back up.
    void updateImpl(Node* v, T index, const V& element, Operation op)
    {
      assertNode(v);
      assert(index >= v->l && index <= v->r);
      if (v->l == v->r) {
        assert(v->l == index);
        updateLeaf(v, element, op);
        return;
      }
      const T mid = midpoint(v);
      const bool goLeft = index <= mid;
      std::unique_ptr<Node>& child = goLeft ? v->left : v->right;
      switch (op)
      {
        case Operation::INSERT: {
          if (!child) {
            child = goLeft ? std::make_unique<Node>(v->l, mid, v)
                           : std::make_unique<Node>(static_cast<T>(mid + 1), v->r, v);
          }
          updateImpl(child.get(), index, element, op);
          break;
        }
        case Operation::ERASE: {
          // Nothing is stored below a missing child, so there is nothing to erase.
          if (child) updateImpl(child.get(), index, element, op);
          break;
        }
        default:
          assert(!"Invalid operation!");
      }
      updateNode(v);
    }

    // Number of values with index in [a, b]; an empty range (a > b) yields 0.
    uint32_t countElementsImpl(Node* v, T a, T b) const {
      if (v == nullptr || a > b) return 0;
      if (v->r < a || v->l > b) return 0;
      if (v->l >= a && v->r <= b) return v->count;
      return countElementsImpl(v->left.get(), a, b) + countElementsImpl(v->right.get(), a, b);
    }

    // Min or max (depending on f) over indices [a, b]; empty if there are no values.
    template <typename F>
    OptRef minMaxImpl(Node* v, T a, T b, F&& f) const
    {
      // count == 0 only happens for an empty root, whose min/max are unset.
      if (v == nullptr || v->count == 0 || a > b) return {};
      if (v->r < a || v->l > b) return {};
      // Fully covered node: f picks the relevant cached extreme.
      if (v->l >= a && v->r <= b) return f(*v->min, *v->max);
      const auto left = minMaxImpl(v->left.get(), a, b, f);
      const auto right = minMaxImpl(v->right.get(), a, b, f);
      if (!left) return right;
      if (!right) return left;
      return f(left->get(), right->get());
    }

    T valueToIndex(const V& value) const {
      return SegmentTreeValueToIndex<T, V>{}(value);
    }

    // Value at the highest occupied index <= x. Relies on values being ordered
    // consistently with their indices (value-indexed tree); this is not checked.
    OptRef lowerEqualImpl(Node* v, T x) const {
      if (v == nullptr || v->count == 0 || x < v->l) return {};
      // A leaf reached here holds index x; a node entirely below x contributes its max.
      if (v->l == v->r || x > v->r) return *v->max;
      const T mid = midpoint(v);
      if (x <= mid) return lowerEqualImpl(v->left.get(), x);
      if (auto fromRight = lowerEqualImpl(v->right.get(), x)) return fromRight;
      // Non-root nodes are never empty, so an existing left child has a max.
      if (v->left) return *v->left->max;
      return {};
    }

    // Value at the lowest occupied index >= x (mirror of lowerEqualImpl).
    OptRef upperEqualImpl(Node* v, T x) const {
      if (v == nullptr || v->count == 0 || x > v->r) return {};
      if (v->l == v->r || x < v->l) return *v->min;
      const T mid = midpoint(v);
      if (x > mid) return upperEqualImpl(v->right.get(), x);
      if (auto fromLeft = upperEqualImpl(v->left.get(), x)) return fromLeft;
      if (v->right) return *v->right->min;
      return {};
    }

    const T size;
    std::unique_ptr<Node> root;
    uint32_t next_element_id{0};
    const Min minOp{};
    const Max maxOp{};
  public:
    SegmentTree()
      : size(getTreeSize(N)),
        root(std::make_unique<Node>(static_cast<T>(1), size))
    { }

    /// Stores `element` at `index` (1 <= index <= size); duplicates are kept.
    void insert(T index, V element) {
      updateImpl(root.get(), index, element, Operation::INSERT);
    }

    /// Removes one value equivalent to `element` (under Cmp) from `index`;
    /// does nothing if no such value is stored there.
    void erase(T index, V element) {
      updateImpl(root.get(), index, element, Operation::ERASE);
    }

    /// Number of values stored at indices in [a, b]; 0 when a > b.
    uint32_t countElements(T a, T b) const {
      return countElementsImpl(root.get(), a, b);
    }

    /// Smallest value (by Cmp) stored at indices in [a, b], if any.
    OptRef min(T a, T b) const {
      return minMaxImpl(root.get(), a, b, minOp);
    }

    /// Largest value (by Cmp) stored at indices in [a, b], if any.
    OptRef max(T a, T b) const {
      return minMaxImpl(root.get(), a, b, maxOp);
    }

    /// Largest value stored at an index <= the index of x, if any.
    OptRef lowerEqual(const V& x) const {
      static_assert(value_indexed, "Cannot use this method in case when value_indexed template parameter is equal to false. "
                                   "This type of operation works only for value indexed trees.");
      return lowerEqualImpl(root.get(), valueToIndex(x));
    }

    /// Smallest value stored at an index strictly greater than the index of x, if any.
    OptRef upper(const V& x) const {
      static_assert(value_indexed, "Cannot use this method in case when value_indexed template parameter is equal to false. "
                                   "This type of operation works only for value indexed trees.");
      return upperEqualImpl(root.get(), valueToIndex(x) + 1);
    }
};
}// end of peki namespace

// Table dimensions used below: N is the index range of every segment tree and
// the number of row slots per type; T is the number of gem-type slots.
constexpr int N = 50007;  // 5e4 + 7
constexpr int T = 57;     // 50 + 7

namespace {
	/// Converts `v` to an int, snapping values within 1e-5 of an integer onto it.
	///
	/// @return {true, k}          when |v - k| <= 1e-5 for the nearest integer k;
	///         {false, floor(v)}  otherwise.
	/// Uses std::round / std::floor instead of static_cast<int> (which truncates
	/// toward zero), so values close to 0 and negative inputs are handled correctly.
	std::pair<bool, int> castToIntWithEpsilon(double v) {
		constexpr double eps = 1e-5;
		const double nearest = std::round(v);
		if (std::fabs(v - nearest) <= eps) {
			return {true, static_cast<int>(nearest)};
		}
		return {false, static_cast<int>(std::floor(v))};
	}
}

namespace logStructure {
	// One segment tree and one ordered set per (gem type, row). A slot is
	// allocated on the first insert into it and released once its set is empty.
	// The tree is indexed by x and stores (x, -id), so the largest value at a
	// column carries the smallest id there; the set stores (x, id) and is used
	// to step to the next gemstone in (x, id) order.
	using TreeT = peki::SegmentTree<int, std::pair<int, int>, N, true>;
	using PointsT = std::set<std::pair<int, int>>;
	std::unique_ptr<TreeT> segT[T][N];
	std::unique_ptr<PointsT> points[T][N];

	// True when (type, y) addresses an existing slot of segT / points.
	bool inTable(int type, int y) {
		return type >= 0 && type < T && y >= 0 && y < N;
	}

	void insert(int type, int x, int y, int id) {
		// Out-of-table coordinates are a caller bug: flag it in debug builds and
		// refuse to write out of bounds in release builds.
		assert(inTable(type, y));
		if(!inTable(type, y)) return;
		if(!segT[type][y]) {
			segT[type][y] = std::make_unique<TreeT>();
			points[type][y] = std::make_unique<PointsT>();
		}
		segT[type][y]->insert(x, {x, -id});
		points[type][y]->emplace(x, id);
	}

	void remove(int type, int x, int y, int id) {
		if(!inTable(type, y) || !segT[type][y]) return;
		segT[type][y]->erase(x, {x, -id});
		points[type][y]->erase({x, id});
		if(points[type][y]->size() == 0) {
			segT[type][y] = nullptr;
			points[type][y] = nullptr;
		}
	}

	// Number of gemstones of `type` on row y with x in [a, b]; 0 for rows
	// outside the table or an empty column range.
	int count(int type, int y, int a, int b) {
		if(!inTable(type, y) || !segT[type][y] || a > b) return 0;
		return segT[type][y]->countElements(a, b);
	}

	// Walks leftwards on row y. With prev_id == -1, returns the smallest id at
	// the largest column <= val. Otherwise (val, prev_id) is the previously
	// returned gemstone: the next larger id at column val is returned, or else
	// the smallest id at the largest column < val. Returns -1 when exhausted.
	int lowerEqual(int type, int y, int val, int prev_id) {
		if(!inTable(type, y) || !segT[type][y]) return -1;

		auto get_lower_equal = [](const TreeT& tree, int value, int previous_id) -> int {
			auto result = tree.lowerEqual({value, previous_id});
			if(result)
				return -result->get().second;
			return -1;
		};

		if(prev_id == -1) {
			return get_lower_equal(*segT[type][y], val, prev_id);
		}
		else {
			auto it = points[type][y]->upper_bound({val, prev_id});
			if(it != points[type][y]->end() && it->first == val){
				return it->second;
			}
			if(val - 1 == 0) return -1;
			return get_lower_equal(*segT[type][y], val - 1, prev_id);
		}
	}

	// Id of the first gemstone on row y strictly after (val, prev_id) in
	// (x, id) order, or -1 if there is none.
	int upper(int type, int y, int val, int prev_id) {
		if(!inTable(type, y) || !segT[type][y]) return -1;
		auto it = points[type][y]->upper_bound({val, prev_id});
		if(it != points[type][y]->end()) {
			return it->second;
		}
		return -1;
	}
}

/// A gemstone (or a query centre): coordinates, id and type.
/// Members default to 0 so a default-constructed Point is never indeterminate.
struct Point {
	int x{0};
	int y{0};
	int id{0};
	int type{0};
	Point(int _x, int _y, int _id, int _type) : x(_x), y(_y), id(_id), type(_type) {

	}
	Point() = default;
};

/// icircle implementation backed by the per-row structures in logStructure.
/// `points` maps a gemstone id to the coordinates and type it was stored with.
class circle : public icircle {
public:
	circle() : points(N) {}

	// icircle interface.
	void addGemstone(int x, int y, int id, int type) override;
	void removeGemstone(int id) override;
	int countGemstones(int x, int y, int r, int type) override;
	int findClosestGemstones(int x, int y, int r, int type, int res[3]) override;

	~circle();

private:
	// Squared Euclidean distance between two points.
	long long dist(const Point &p1, const Point &p2);
	// Whether p1 and p2 are at most r apart.
	bool checkDist(const Point &p1, const Point &p2, int r);

	// Inclusive column range [first, second] of row `ya` covered by the circle.
	std::pair<int, int> getInterval(int x, int y, int ya, int r);
	// Up to three nearest gemstones on row `ya`, as (squared distance, id).
	std::vector<std::pair<long long, int>> getCandidates(int type, int x, int y, int ya);

	std::vector<Point> points;
};

long long circle::dist(const Point &p1, const Point &p2) {
	// Widen each difference before squaring so the products cannot overflow int.
	const long long dx = p1.x - p2.x;
	const long long dy = p1.y - p2.y;
	return dx * dx + dy * dy;
}

bool circle::checkDist(const Point &p1, const Point &p2, int r) {
	// Compare squared distances to stay in exact integer arithmetic.
	const long long radiusSq = static_cast<long long>(r) * r;
	return dist(p1, p2) <= radiusSq;
}

pair <int, int> circle::getInterval(int x, int y, int ya, int r)
{
	// An integer point (px, ya) lies in the circle iff
	//   (px - x)^2 + (ya - y)^2 <= r^2  <=>  |px - x| <= floor(sqrt(r^2 - (ya - y)^2)),
	// so the covered columns are computed exactly with an integer square root
	// rather than by solving the quadratic in floating point and rounding.
	const long long dy = static_cast<long long>(ya) - y;
	const long long rest = static_cast<long long>(r) * r - dy * dy;
	// Callers only pass rows with |ya - y| <= r, so the row always meets the circle.
	assert(rest >= 0);
	if (rest < 0) {
		// Row misses the circle: return an empty interval (first > second).
		return {x + 1, x};
	}

	// floor(sqrt(rest)), corrected for any floating-point rounding.
	long long half = static_cast<long long>(std::sqrt(static_cast<double>(rest)));
	while (half > 0 && half * half > rest) --half;
	while ((half + 1) * (half + 1) <= rest) ++half;

	// Columns are 1-based, so the left end is clamped to 1.
	const long long first = std::max(1LL, static_cast<long long>(x) - half);
	const long long last = static_cast<long long>(x) + half;
	return {static_cast<int>(first), static_cast<int>(last)};
}

void circle::addGemstone(int x, int y, int id, int type)
{
	x++, y++;
	points[id]= {x, y, id, type};
	logStructure::insert(0, x, y, id);
	logStructure::insert(type, x, y, id);
}

void circle::removeGemstone(int id)
{
	const auto& point = points[id];
	logStructure::remove(0, point.x, point.y, point.id);
	logStructure::remove(point.type, point.x, point.y, point.id);
}

int circle::countGemstones(int x, int y, int r, int type)
{
	// Stored coordinates are shifted by one (see addGemstone).
	x++; y++;
	// logStructure tables have N row slots (0 .. N-1), so the last usable row is
	// N - 1; clamping to N would read one slot past the end.
	const int ya = std::max(y - r, 1);
	const int yb = std::min(y + r, N - 1);
	int sum = 0;
	for (int row = ya; row <= yb; ++row) {
		const auto [a, b] = getInterval(x, y, row, r);
		// Defensive: an empty column range contributes nothing.
		if (a > b) continue;
		sum += logStructure::count(type, row, a, b);
	}
	return sum;
}

vector<pair<long long, int>> circle::getCandidates(int type, int x, int y, int ya)
{
	// Two cursors walk away from column x on row ya: one leftwards (columns <= x)
	// and one rightwards (columns > x). Distances grow along each walk, so
	// repeatedly taking the nearer cursor yields gemstones by (distance, id).
	const Point centre{x, y, -1, -1};
	auto distanceTo = [this, &centre](int id) { return dist(points[id], centre); };

	// True when the next candidate should come from the left cursor.
	auto takeLeft = [&](int leftID, int rightID) -> bool {
		if (leftID == -1) return false;
		if (rightID == -1) return true;
		const long long leftDist = distanceTo(leftID);
		const long long rightDist = distanceTo(rightID);
		if (leftDist != rightDist) return leftDist < rightDist;
		return leftID < rightID;
	};

	constexpr int kMaxPerRow = 3;
	vector<pair<long long, int>> picked;
	int leftID = logStructure::lowerEqual(type, ya, x, -1);
	int rightID = logStructure::upper(type, ya, x, N);
	for (int taken = 0; taken < kMaxPerRow && (leftID > -1 || rightID > -1); ++taken) {
		if (takeLeft(leftID, rightID)) {
			picked.emplace_back(distanceTo(leftID), leftID);
			leftID = logStructure::lowerEqual(type, ya, points[leftID].x, leftID);
		} else {
			picked.emplace_back(distanceTo(rightID), rightID);
			rightID = logStructure::upper(type, ya, points[rightID].x, rightID);
		}
	}
	return picked;
}

int circle::findClosestGemstones(int x, int y, int r, int type, int res[3])
{
	x++; y++;
	int ya = std::max(y - r, 1);
	int yb = std::min(y + r, N);

	vector <pair<long long, int>> candidates;
	for(int i = ya; i <= yb; i++) {
		auto newCandidates = getCandidates(type, x, y, i);
		std::copy(newCandidates.begin(), newCandidates.end(), std::back_inserter(candidates));
	}

	std::sort(candidates.begin(), candidates.end());
	constexpr std::size_t c_max = 3;
	std::size_t count = 0;
	for(;count < std::min(candidates.size(), c_max)
							 && candidates[count].first <= static_cast<long long>(r) * r;
			count++)
	{
		res[count] = candidates[count].second;
	}
	return count;
}

circle::~circle() {
	// Release every per-(type, row) structure held in the global tables.
	for (int t = 0; t < T; ++t) {
		auto& treeRow = logStructure::segT[t];
		auto& setRow = logStructure::points[t];
		for (int row = 0; row < N; ++row) {
			treeRow[row] = nullptr;
			setRow[row] = nullptr;
		}
	}
}

int main() {
	circle c;
	judge::run(&c);
	return 0l;
}
// Leave a Comment