# Old path:
# pytorch tensor -> shard -> distribute to devices
1. ttnn.from_torch was called with a mesh mapper supplied
2. from_torch used the mesh mapper to split the tensor into shards
3. The individual shards were assembled into a multi-device tensor
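
A minimal sketch of the old path for context; the mesh shape, tensor shape, and shard dims below are illustrative assumptions, not values from this note:

import torch
import ttnn

# Open an (assumed) 2x4 device mesh
mesh_device = ttnn.open_mesh_device(mesh_shape=ttnn.MeshShape(2, 4))
torch_tensor = torch.rand(1, 1, 64, 128)

# The mapper is handed straight to from_torch, which shards the torch
# tensor and assembles the shards into a multi-device tensor in one call
mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=(2, 4), dims=(2, 3))
multi_device_tensor = ttnn.from_torch(
    torch_tensor,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=mesh_device,
    mesh_mapper=mapper,
)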

# New path:
# pytorch tensor -> ttnn -> distribute to devices
1. ttnn.from_torch converts the pytorch tensor to a ttnn tensor, with no mesh mapper involved
2. Create a mesh sharding function via the exposed shard_tensor_to_2d_mesh_mapper(...)
3. Shard the ttnn tensor created in (1) using the sharding function created in (2), as sketched below
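
A sketch of the new path; the exact signature of shard_tensor_to_2d_mesh_mapper and the name of the final sharding step (distribute_tensor here) are assumptions for illustration:

import torch
import ttnn

mesh_device = ttnn.open_mesh_device(mesh_shape=ttnn.MeshShape(2, 4))  # assumed 2x4 mesh

# Step 1: plain conversion to a ttnn tensor, no mesh mapper involved
ttnn_tensor = ttnn.from_torch(
    torch.rand(1, 1, 64, 128), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT
)

# Step 2: create the sharding function (argument names are assumptions)
mapper = ttnn.shard_tensor_to_2d_mesh_mapper(
    mesh_device, mesh_shape=ttnn.MeshShape(2, 4), dims=(2, 3)
)

# Step 3: shard the tensor from step 1 onto the mesh using the mapper
multi_device_tensor = ttnn.distribute_tensor(ttnn_tensor, mapper, mesh_device)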

# The implementation could still supply the sharding function as a mesh mapper in from_torch:
new_mesh_mapper = ttnn.shard_tensor_to_2d_mesh_mapper(...)
ttnn.from_torch(..., mesh_mapper=new_mesh_mapper)
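
Design-wise this keeps existing from_torch call sites working unchanged: the object returned by shard_tensor_to_2d_mesh_mapper(...) slots in where the old mesh mapper went, while the sharding itself now lives behind a function that can also be called standalone on a ttnn tensor.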

* The same applies to concat, just in reverse: gather the shards of a multi-device tensor back into a single ttnn tensor, then convert to torch, as sketched below
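
A sketch of the reverse (concat) direction, continuing the example above; the composer name (concat_2d_mesh_to_tensor_composer) and the aggregate_tensor step are assumptions for illustration:

# Gather the per-device shards back along the same dims they were split on
composer = ttnn.concat_2d_mesh_to_tensor_composer(mesh_device, dims=(2, 3))
gathered = ttnn.aggregate_tensor(multi_device_tensor, composer)  # single ttnn tensor
torch_out = ttnn.to_torch(gathered)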