# Old path: pytorch tensor -> shard -> distribute to devices

1. `ttnn.from_torch` was called with a mesh mapper supplied.
2. `from_torch` used the mesh mapper to split the torch tensor.
3. Each individual tensor (shard) was assembled into a multi-device tensor.

# New path: pytorch tensor -> ttnn -> distribute to devices

1. Convert the torch tensor with `ttnn.from_torch`.
2. Create a mesh sharding function via the exposed `shard_tensor_to_2d_mesh_mapper(...)`.
3. Shard the ttnn tensor created in (1) using the sharding function created in (2).

The implementation could supply the sharding function as the mesh mapper in `from_torch`: `new_mesh_mapper = ttnn.shard_tensor_to_2d_mesh_mapper(...)`, then `ttnn.from_torch(..., mesh_mapper=new_mesh_mapper)`.

* The same applies to concat, just in reverse.
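A minimal, library-free sketch of the two paths follows. It rests on a few assumptions: a "tensor" is a Python list of rows, and a 2D mesh of shape (rows, cols) splits dim 0 across mesh rows and dim 1 across mesh columns. It does not call ttnn. `make_2d_mesh_mapper`, `convert_to_ttnn`, `from_torch_old`, `from_torch_new` and `concat_2d` are hypothetical stand-ins, and a "multi-device tensor" is modeled as a dict from mesh coordinate to shard. What it illustrates is that both paths yield the same per-device shards. Only the order of conversion and sharding differs, and concat is the inverse of the 2D shard.

```python
from typing import Callable, Dict, List, Tuple

Matrix = List[List[int]]                 # stand-in for a 2D tensor
Coord = Tuple[int, int]                  # (mesh_row, mesh_col)
MeshTensor = Dict[Coord, Matrix]         # stand-in for a multi-device tensor
Mapper = Callable[[Matrix], MeshTensor]  # a "sharding function"


def shard_2d(t: Matrix, mesh_shape: Tuple[int, int]) -> MeshTensor:
    """Split dim 0 across mesh rows and dim 1 across mesh columns (even split assumed)."""
    mesh_rows, mesh_cols = mesh_shape
    n_rows, n_cols = len(t), len(t[0])
    assert n_rows % mesh_rows == 0 and n_cols % mesh_cols == 0, "uneven split"
    rh, cw = n_rows // mesh_rows, n_cols // mesh_cols
    return {
        (r, c): [row[c * cw:(c + 1) * cw] for row in t[r * rh:(r + 1) * rh]]
        for r in range(mesh_rows)
        for c in range(mesh_cols)
    }


def make_2d_mesh_mapper(mesh_shape: Tuple[int, int]) -> Mapper:
    """Hypothetical analogue of shard_tensor_to_2d_mesh_mapper(...): returns a sharding function."""
    return lambda t: shard_2d(t, mesh_shape)


def convert_to_ttnn(t: Matrix) -> Matrix:
    """Hypothetical stand-in for the torch -> ttnn conversion (here just a copy)."""
    return [list(row) for row in t]


def from_torch_old(torch_t: Matrix, mapper: Mapper) -> MeshTensor:
    # Old path (sketch): split the torch tensor first, then convert and assemble each piece.
    return {coord: convert_to_ttnn(piece) for coord, piece in mapper(torch_t).items()}


def from_torch_new(torch_t: Matrix, mapper: Mapper) -> MeshTensor:
    # New path (sketch): convert once, then shard the converted tensor with the mapper.
    return mapper(convert_to_ttnn(torch_t))


def concat_2d(mesh_t: MeshTensor, mesh_shape: Tuple[int, int]) -> Matrix:
    """Reverse of shard_2d: stitch each mesh row side by side, then stack the mesh rows."""
    mesh_rows, mesh_cols = mesh_shape
    out: Matrix = []
    for r in range(mesh_rows):
        row_shards = [mesh_t[(r, c)] for c in range(mesh_cols)]
        for i in range(len(row_shards[0])):
            out.append([x for shard in row_shards for x in shard[i]])
    return out


if __name__ == "__main__":
    t = [[r * 4 + c for c in range(4)] for r in range(4)]  # 4x4 tensor, values 0..15
    mapper = make_2d_mesh_mapper((2, 2))                     # 2x2 mesh
    assert from_torch_old(t, mapper) == from_torch_new(t, mapper)
    print(from_torch_new(t, mapper)[(1, 0)])  # [[8, 9], [12, 13]]
    assert concat_2d(from_torch_new(t, mapper), (2, 2)) == t
```

In this sketch the mapper is just a function. The same object can therefore either be handed to the converter or be applied after conversion, which mirrors the idea of passing the result of `shard_tensor_to_2d_mesh_mapper(...)` as the mesh mapper in `from_torch`.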