Untitled

 avatar
unknown
plain_text
2 months ago
4.4 kB
14
Indexable

class MeshDeviceFixtureBase : public ::testing::Test {
protected:
    using MeshDevice = ::tt::tt_metal::distributed::MeshDevice;
    using MeshDeviceConfig = ::tt::tt_metal::distributed::MeshDeviceConfig;
    using MeshShape = ::tt::tt_metal::distributed::MeshShape;

    enum class MeshDeviceType {
        N300,
        T3000,
    };
    
    struct Config {
        // If unset, the mesh device type will be deduced automatically based on the connected devices.
        std::optional<MeshDeviceType> mesh_device_type;
        int num_cqs = 1;
    };


    MeshDeviceFixtureBase(const Config& fixture_config) : config_(fixture_config) {}

    void SetUp() override {
        auto slow_dispatch = getenv("TT_METAL_SLOW_DISPATCH_MODE");
        if (slow_dispatch) {
            GTEST_SKIP() << "Skipping Mesh-Device test suite, since it can only be run in Fast Dispatch Mode.";
        }

        const auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name());
        if (arch != tt::ARCH::WORMHOLE_B0) {
            GTEST_SKIP() << "Skipping MeshDevice test suite on a non-wormhole machine.";
        }

        const auto num_devices = tt::tt_metal::GetNumAvailableDevices();
        const auto mesh_device_type = derive_mesh_device_type(num_devices);
        if (!mesh_device_type) {
            GTEST_SKIP() << fmt::format(
                "Skipping MeshDevice test suite on a machine with an unsupported number of devices {}.",
                num_devices);
        }

        if (config_.mesh_device_type.has_value() && *config_.mesh_device_type != *mesh_device_type) {
            GTEST_SKIP() << fmt::format(
                "Skipping MeshDevice test suite on a {} machine that does not match the configured mesh device type {}",
                magic_enum::enum_name(*mesh_device_type),
                magic_enum::enum_name(*config_.mesh_device_type));
        }

        // Use ethernet dispatch for more than 1 CQ on T3K/N300
        DispatchCoreType core_type = (config_.num_cqs >= 2) ? DispatchCoreType::ETH : DispatchCoreType::WORKER;
        mesh_device_ = MeshDevice::create(
            MeshDeviceConfig{.mesh_shape = get_mesh_shape(*mesh_device_type)}, 0, 0, config_.num_cqs, core_type);
    }

    void TearDown() override {
        if (!mesh_device_) {
            return;
        }
        mesh_device_->close();
        mesh_device_.reset();
    }

    std::shared_ptr<tt::tt_metal::distributed::MeshDevice> mesh_device_;

private:
    // Returns the mesh shape for a given mesh device type.
    MeshShape get_mesh_shape(MeshDeviceType mesh_device_type) {
        switch (mesh_device_type) {
            case MeshDeviceType::N300: return MeshShape(2, 1);
            case MeshDeviceType::T3000: return MeshShape(2, 4);
        }
    }

    // Determines the mesh device type based on the number of devices.
    std::optional<MeshDeviceType> derive_mesh_device_type(size_t num_devices) {
        switch (num_devices) {
            case 2: return MeshDeviceType::N300;
            case 8: return MeshDeviceType::T3000;
        }
        return std::nullopt;
    }

    Config config_;
};

// Fixtures that determine the mesh device type automatically.
class MeshDeviceFixture : public MeshDeviceFixtureBase {
protected:
    MeshDeviceFixture() : MeshDeviceFixtureBase(Config{.num_cqs = 1}) {}
};

class MultiCQMeshDeviceFixture : public MeshDeviceFixtureBase {
protected:
    MultiCQMeshDeviceFixture() : MeshDeviceFixtureBase(Config{.num_cqs = 2}) {}
};

// Fixtures that specify the mesh device type explicitly.
class N300MeshDeviceFixture : public MeshDeviceFixtureBase {
protected:
    N300MeshDeviceFixture() : MeshDeviceFixtureBase(Config{.mesh_device_type = MeshDeviceType::N300}) {}
};

class T3000MeshDeviceFixture : public MeshDeviceFixtureBase {
protected:
    T3000MeshDeviceFixture() : MeshDeviceFixtureBase(Config{.mesh_device_type = MeshDeviceType::T3000}) {}
};

class N300MultiCQMeshDeviceFixture : public MeshDeviceFixtureBase {
protected:
    N300MultiCQMeshDeviceFixture() :
        MeshDeviceFixtureBase(Config{.mesh_device_type = MeshDeviceType::N300, .num_cqs = 2}) {}
};

class T3000MultiCQMeshDeviceFixture : public MeshDeviceFixtureBase {
protected:
    T3000MultiCQMeshDeviceFixture() :
        MeshDeviceFixtureBase(Config{.mesh_device_type = MeshDeviceType::T3000, .num_cqs = 2}) {}
};
Editor is loading...
Leave a Comment