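# Untitled
#
# Paste-site metadata (pastecode.io), kept as comments so the file parses:
#   mail@pastecode.io avatar
#   Author: unknown | Format: plain_text | Posted: a year ago | Size: 3.8 kB
#   1 | Indexable | Never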
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.mplot3d import Axes3D
import cv2
from numpy.linalg import inv
import re

def demo(G, imgpath):
    """Run generator ``G`` on one image, smooth the predicted strands and plot them.

    Args:
        G: a torch module, already placed on its target device. It is
            switched to eval mode here.
        imgpath: path to an image readable by ``cv2.imread``.

    The image is converted the way ``torchvision.transforms.ToTensor`` would
    convert it: HWC uint8 becomes CHW float in [0, 1]. The channel order
    stays BGR, as OpenCV returns it. A batch dimension is then added.
    ``output[0]`` is treated as a strand array indexed
    ``[point, channel, row, col]`` (presumably [100, 4, 32, 32]; TODO
    confirm against the model). Channels 0-2 of every strand are smoothed
    along the point axis with a 10-tap Gaussian (sigma 3) before plotting.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``imgpath``.
    """
    import torch

    net = G
    net.eval()
    #net.load_state_dict(torch.load('/kaggle/input/weight1062/weight (5).pt'))

    img = cv2.imread(imgpath)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {imgpath}")

    # Same conversion as torchvision.transforms.ToTensor(). The original code
    # called `transforms` without importing it.
    tensor = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
    tensor = tensor.float().div(255)

    # Feed the input to wherever the model's weights live. The original
    # code hard-coded .cuda(); that is kept as the fallback for a model
    # with no parameters.
    param = next(net.parameters(), None)
    device = param.device if param is not None else torch.device("cuda")
    # unsqueeze adds the batch axis without assuming a 256x256 image, and
    # works on non-contiguous tensors (unlike view).
    tensor = tensor.unsqueeze(0).to(device)

    with torch.no_grad():  # inference only: no autograd graph needed
        output = net(tensor)
    print(output.shape)

    strands = output[0].cpu().detach().numpy()

    # getGaussianKernel returns a (10, 1) column kernel, so filter2D smooths
    # each x/y/z column along the point axis (axis 0) only.
    gaussian = cv2.getGaussianKernel(10, 3)
    for i in range(strands.shape[2]):
        for j in range(strands.shape[3]):
            strands[:, :3, i, j] = cv2.filter2D(strands[:, :3, i, j], -1, gaussian)

    show3DhairPlotByStrands(strands)


def example(convdata):
    """Load a saved strand array from ``convdata`` and plot it.

    The data is read with ``np.load``, reshaped to (100, 4, 32, 32) and
    handed to ``show3DhairPlotByStrands``.
    """
    loaded = np.load(convdata)
    grid = loaded.reshape((100, 4, 32, 32))
    show3DhairPlotByStrands(grid)

def set_axes_equal(ax: plt.Axes):
    """Give a 3D axes the same scale on x, y and z.

    ``ax.axis('equal')`` and ``ax.set_aspect('equal')`` do not work on 3D
    axes. Instead, this reads the current limits and keeps every axis centred
    where it already is. Each axis is then widened to the half-width of the
    largest current span, so spheres look like spheres and cubes like cubes.
    """
    getters = (ax.get_xlim3d, ax.get_ylim3d, ax.get_zlim3d)
    bounds = np.array([get_limits() for get_limits in getters])
    centre = bounds.mean(axis=1)
    spans = np.abs(bounds[:, 1] - bounds[:, 0])
    half_width = spans.max() / 2
    _set_axes_radius(ax, centre, half_width)
    
def _set_axes_radius(ax, origin, radius):
    x, y, z = origin
    ax.set_xlim3d([x - radius, x + radius])
    ax.set_ylim3d([y - radius, y + radius])
    ax.set_zlim3d([z - radius, z + radius])
    
def Wave(t, A=0.01, B=0.01, alpha=1, C=0.01, beta=1, D=0.01, t0=2, R=5, P=1, theta=0, Bias=0):
    """Evaluate a sinusoid with a time-varying amplitude and a chirped phase.

    amplitude(t) = A + B*t*exp(-alpha*t) + C*(1 - exp(-beta*t)) + D*exp(t - t0)
    phase(t)     = 2*pi*(R*t + P)*t + theta
    result       = amplitude(t) * sin(phase(t)) + Bias

    ``t`` may be a scalar or a NumPy array; the result is computed elementwise.
    """
    envelope = (
        A
        + B * t * np.exp(-alpha * t)
        + C * (1 - np.exp(-beta * t))
        + D * np.exp(t - t0)
    )
    phase = 2 * np.pi * (R * t + P) * t + theta
    return envelope * np.sin(phase) + Bias
    
def show3DhairPlotByStrands(strands):
    """Plot every non-empty hair strand as a 3D polyline and show the figure.

    strands: array indexed as ``strands[point, channel, row, col]``. The rest
        of this script uses shape [100, 4, 32, 32]. Channels 0-2 are plotted
        as x, y, z; any further channels are ignored for plotting. The grid
        size and strand length are read from ``strands.shape``, not
        hard-coded.

    A grid cell whose values are all zero is treated as "no strand" and
    skipped. The view is centred on the mean of the plotted points, with a
    half-width of 0.3 on every axis, and looks straight down (elev=90,
    azim=-90). If nothing is plotted, the view is centred on the origin.
    """
    strands = np.asarray(strands)
    _, _, n_rows, n_cols = strands.shape

    fig = plt.figure(figsize=(40, 40))
    ax = fig.add_subplot(111, projection="3d")

    centres = []
    for i in range(n_rows):
        for j in range(n_cols):
            strand = strands[:, :, i, j]
            # np.any means "at least one non-zero value". The previous
            # sum()==0 test also dropped real strands whose values cancelled.
            if not np.any(strand):
                continue
            x, y, z = strand[:, 0], strand[:, 1], strand[:, 2]
            ax.plot(x, y, z, linewidth=0.2, color="brown")
            centres.append((x.mean(), y.mean(), z.mean()))

    # Mean over strands of each strand's mean point. Fall back to the origin
    # rather than dividing by zero when no strand was drawn.
    if centres:
        avgx, avgy, avgz = np.mean(centres, axis=0)
    else:
        avgx, avgy, avgz = 0.0, 0.0, 0.0

    RADIUS = 0.3  # space around the head
    ax.set_xlim3d([avgx - RADIUS, avgx + RADIUS])
    ax.set_ylim3d([avgy - RADIUS, avgy + RADIUS])
    ax.set_zlim3d([avgz - RADIUS, avgz + RADIUS])
    # Top-down camera. The old `ax.dist = 10` is dropped: 10 was already the
    # default, and the attribute is deprecated in newer Matplotlib.
    ax.view_init(elev=90, azim=-90)
    ax.set_box_aspect([1, 1, 1])
    set_axes_equal(ax)
    plt.show()