Untitled
unknown
plain_text
10 months ago
4.4 kB
20
Indexable
class MeshDeviceFixtureBase : public ::testing::Test {
protected:
using MeshDevice = ::tt::tt_metal::distributed::MeshDevice;
using MeshDeviceConfig = ::tt::tt_metal::distributed::MeshDeviceConfig;
using MeshShape = ::tt::tt_metal::distributed::MeshShape;
enum class MeshDeviceType {
N300,
T3000,
};
struct Config {
// If unset, the mesh device type will be deduced automatically based on the connected devices.
std::optional<MeshDeviceType> mesh_device_type;
int num_cqs = 1;
};
MeshDeviceFixtureBase(const Config& fixture_config) : config_(fixture_config) {}
void SetUp() override {
auto slow_dispatch = getenv("TT_METAL_SLOW_DISPATCH_MODE");
if (slow_dispatch) {
GTEST_SKIP() << "Skipping Mesh-Device test suite, since it can only be run in Fast Dispatch Mode.";
}
const auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name());
if (arch != tt::ARCH::WORMHOLE_B0) {
GTEST_SKIP() << "Skipping MeshDevice test suite on a non-wormhole machine.";
}
const auto num_devices = tt::tt_metal::GetNumAvailableDevices();
const auto mesh_device_type = derive_mesh_device_type(num_devices);
if (!mesh_device_type) {
GTEST_SKIP() << fmt::format(
"Skipping MeshDevice test suite on a machine with an unsupported number of devices {}.",
num_devices);
}
if (config_.mesh_device_type.has_value() && *config_.mesh_device_type != *mesh_device_type) {
GTEST_SKIP() << fmt::format(
"Skipping MeshDevice test suite on a {} machine that does not match the configured mesh device type {}",
magic_enum::enum_name(*mesh_device_type),
magic_enum::enum_name(*config_.mesh_device_type));
}
// Use ethernet dispatch for more than 1 CQ on T3K/N300
DispatchCoreType core_type = (config_.num_cqs >= 2) ? DispatchCoreType::ETH : DispatchCoreType::WORKER;
mesh_device_ = MeshDevice::create(
MeshDeviceConfig{.mesh_shape = get_mesh_shape(*mesh_device_type)}, 0, 0, config_.num_cqs, core_type);
}
void TearDown() override {
if (!mesh_device_) {
return;
}
mesh_device_->close();
mesh_device_.reset();
}
std::shared_ptr<tt::tt_metal::distributed::MeshDevice> mesh_device_;
private:
// Returns the mesh shape for a given mesh device type.
MeshShape get_mesh_shape(MeshDeviceType mesh_device_type) {
switch (mesh_device_type) {
case MeshDeviceType::N300: return MeshShape(2, 1);
case MeshDeviceType::T3000: return MeshShape(2, 4);
}
}
// Determines the mesh device type based on the number of devices.
std::optional<MeshDeviceType> derive_mesh_device_type(size_t num_devices) {
switch (num_devices) {
case 2: return MeshDeviceType::N300;
case 8: return MeshDeviceType::T3000;
}
return std::nullopt;
}
Config config_;
};
// Fixtures that determine the mesh device type automatically.
class MeshDeviceFixture : public MeshDeviceFixtureBase {
protected:
MeshDeviceFixture() : MeshDeviceFixtureBase(Config{.num_cqs = 1}) {}
};
class MultiCQMeshDeviceFixture : public MeshDeviceFixtureBase {
protected:
MultiCQMeshDeviceFixture() : MeshDeviceFixtureBase(Config{.num_cqs = 2}) {}
};
// Fixtures that specify the mesh device type explicitly.
class N300MeshDeviceFixture : public MeshDeviceFixtureBase {
protected:
N300MeshDeviceFixture() : MeshDeviceFixtureBase(Config{.mesh_device_type = MeshDeviceType::N300}) {}
};
class T3000MeshDeviceFixture : public MeshDeviceFixtureBase {
protected:
T3000MeshDeviceFixture() : MeshDeviceFixtureBase(Config{.mesh_device_type = MeshDeviceType::T3000}) {}
};
class N300MultiCQMeshDeviceFixture : public MeshDeviceFixtureBase {
protected:
N300MultiCQMeshDeviceFixture() :
MeshDeviceFixtureBase(Config{.mesh_device_type = MeshDeviceType::N300, .num_cqs = 2}) {}
};
class T3000MultiCQMeshDeviceFixture : public MeshDeviceFixtureBase {
protected:
T3000MultiCQMeshDeviceFixture() :
MeshDeviceFixtureBase(Config{.mesh_device_type = MeshDeviceType::T3000, .num_cqs = 2}) {}
};
Editor is loading...
Leave a Comment