class VAE(torch.nn.Module):
    """Simple MLP variational autoencoder.

    Args:
        input_dim: Dimensionality of the (flattened) input data.
        hidden_dims: Hidden layer sizes for the encoder; the decoder mirrors
            them in reverse. The latent size is ``hidden_dims[-1] // 2``.
        decode_dim: Output dimensionality of the decoder. ``-1`` (the
            default) means "same as ``input_dim``".
        use_sigmoid: Currently unused -- the output layer always applies
            Tanh. Kept only for signature compatibility with existing callers.
    """

    def __init__(self, input_dim, hidden_dims, decode_dim=-1, use_sigmoid=True):
        super().__init__()
        self.z_size = hidden_dims[-1] // 2
        # Fix: honor decode_dim (it was previously accepted but ignored and
        # the decoder always produced input_dim features). The default of -1
        # preserves the old behavior exactly.
        out_dim = decode_dim if decode_dim > 0 else input_dim

        encoder_layers = []
        decoder_layers = []
        # One shared counter per layer-type name keeps module names unique
        # across both Sequentials.
        counts = defaultdict(int)

        def add_encoder_layer(name: str, layer: torch.nn.Module) -> None:
            encoder_layers.append((f"{name}{counts[name]}", layer))
            counts[name] += 1

        def add_decoder_layer(name: str, layer: torch.nn.Module) -> None:
            decoder_layers.append((f"{name}{counts[name]}", layer))
            counts[name] += 1

        # Encoder: Linear + LeakyReLU for each hidden dim, in order.
        input_channel = input_dim
        encoder_dims = hidden_dims
        for x in hidden_dims:
            add_encoder_layer("mlp", torch.nn.Linear(input_channel, x))
            add_encoder_layer("relu", torch.nn.LeakyReLU())
            input_channel = x

        # Decoder mirrors the encoder, starting from the latent code.
        decoder_dims = encoder_dims[::-1]
        input_channel = self.z_size
        for x in decoder_dims:
            add_decoder_layer("mlp", torch.nn.Linear(input_channel, x))
            add_decoder_layer("relu", torch.nn.LeakyReLU())
            input_channel = x

        # Heads producing the Gaussian posterior parameters of q(z|x).
        self.fc_mean = torch.nn.Sequential(
            torch.nn.Linear(encoder_dims[-1], self.z_size),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(self.z_size, self.z_size),
        )
        self.fc_var = torch.nn.Sequential(
            torch.nn.Linear(encoder_dims[-1], self.z_size),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(self.z_size, self.z_size),
        )
        self.encoder = torch.nn.Sequential(OrderedDict(encoder_layers))
        self.decoder = torch.nn.Sequential(OrderedDict(decoder_layers))
        self.out_layer = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(decoder_dims[-1], out_dim),
            torch.nn.Tanh(),  # reconstructions are squashed to [-1, 1]
        )

    def encode(self, x):
        """Encode ``x`` and return ``(mean, logvar)`` of q(z|x)."""
        res = self.encoder(x)
        return self.fc_mean(res), self.fc_var(res)

    def reparameterize(self, mean, logvar, n_samples_per_z=1):
        """Draw z = mean + std * eps, one epsilon per row of ``mean``.

        ``n_samples_per_z`` is accepted for interface compatibility only:
        ``forward`` pre-expands ``mean``/``logvar`` to the full sample count,
        so exactly one noise sample per row is drawn here.
        """
        d, latent_dim = mean.size()
        # exp() already yields a tensor on mean.device; no transfer needed.
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn(d, latent_dim, device=mean.device)
        return mean + std * epsilon

    def decode(self, z):
        """Map latent codes back to data space (Tanh output in [-1, 1])."""
        return self.out_layer(self.decoder(z))

    def forward(self, x, n_samples_per_z=1):
        """Run the full VAE pass.

        Args:
            x: Input batch of shape ``(batch, input_dim)``.
            n_samples_per_z: Number of latent samples drawn per input;
                reconstructions are averaged over these samples.

        Returns:
            dict with ``imgs`` (averaged reconstructions,
            ``(batch, out_dim)``), ``z`` (samples,
            ``(batch * n_samples_per_z, z_size)``), and the flattened
            ``mean`` / ``logvar`` of matching shape.
        """
        mean, logvar = self.encode(x)
        batch_size, latent_dim = mean.shape
        if n_samples_per_z > 1:
            # Repeat each posterior n_samples_per_z times ...
            mean = mean.unsqueeze(1).expand(batch_size, n_samples_per_z, latent_dim)
            logvar = logvar.unsqueeze(1).expand(batch_size, n_samples_per_z, latent_dim)
        # ... then flatten (sample, batch) into one big batch for decoding.
        mean = mean.contiguous().view(batch_size * n_samples_per_z, latent_dim)
        logvar = logvar.contiguous().view(batch_size * n_samples_per_z, latent_dim)

        z = self.reparameterize(mean, logvar, n_samples_per_z)
        x_probs = self.decode(z)

        # Average the reconstructions over the per-input z samples.
        x_probs = x_probs.reshape(batch_size, n_samples_per_z, -1)
        x_probs = torch.mean(x_probs, dim=1)

        return {
            "imgs": x_probs,
            "z": z,
            "mean": mean,
            "logvar": logvar,
        }
# Smoke test: push one random vector through an untrained VAE and print the
# resulting reconstruction dict.
# NOTE(review): `device` (and the torch import) are defined earlier in the
# file, outside this chunk -- confirm against the full script.
hidden_dims = [128, 64, 36, 18, 18]  # latent size will be 18 // 2 = 9
input_dim = 256
test_tensor = torch.randn([1, input_dim]).to(device)
vae_test = VAE(input_dim, hidden_dims).to(device)
# Inference only: disable autograd bookkeeping for the forward pass.
with torch.no_grad():
    test_out = vae_test(test_tensor)
    print(test_out)