본문 바로가기
인공지능/CV

GAN에서의 미분 (Pytorch)

by EXUPERY 2021. 10. 13.
반응형

GAN에서의 미분 (Pytorch)

 

 

https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

Pytorch에서 제공하는 DCGAN Tutorial에서 훈련스텝은 아래 코드와같다.

for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

Tensorflow에서 제공하는 DCGAN Tutorial에서의 미분과정보다 코드가 더 복잡해보인다. Pytorch는 어떤 값을 자동미분해줄지 직접 작성해주어야하고, 한번 미분한 값은 computation이 없어지기 때문이다.

 

netD.zero_grad() # 판별자 Gradient 초기화

# 진짜이미지에대한 기울기구하기
real_img = data[0].to(device)
b_size = real_img.size(0) # 가독성을위해 real_cpu -> real_img로 변수명 변환
label = torch.full((b_size,), real_label, dtype=torch.float, device=device) # 라벨을 이미지의 사이즈만큼 만들어줌

output = netD(real_img).view(-1) # 판별자가 진짜이미지를 판단함
errD_real = criterion(output, label) # 1로 판단하는지의 손실을 구한다

errD_real.backward()	# 진짜이미지에대한 미분

# 가짜이미지에대한 기울기구하기
fake = netG(noise)	# 생성기에서 가짜이미지생성
label.fill_(fake_label)	# 라벨링 0
output = netD(fake.detach()).view(-1)	# 가짜이미지 detach() 후 판별
errD_fake = criterion(output, label)	# 0으로 판단하는지의 손실을 구한다

errD_fake.backward()	# 가짜이미지에대한 미분

optimizerD.step()	# 최적화

Discriminator의 훈련과정만 뺀 코드이다. 중요한것은 Generator에 이미지를 넣어 생성된 출력이미지(fake)를 detach해준다. 현재는 판별자의 train step이고, 판별자는 라벨링된 이미지를 통해 훈련하는 것이 전부이기때문에 생성자의 연산까지 가져올 필요가 없다. 또한 가져온다면 다음에 사용하지 못한다. (retain_graph=True를 사용하지 않는 한)

 

netG.zero_grad()
label.fill_(real_label) # 생성자에게는 판별자가 진짜라고 판단하는 것이 라벨이다
output = netD(fake).view(-1)	# 판별자의 판단에 따라서 생성자를 훈련시키는 것. detach가 없다.

errG = criterion(output, label)	# 판별자의 결과를 통한 훈련
errG.backward()	# 손실 미분
optimizerG.step()	# 최적화

Generaotr 훈련코드를 보면, fake에서 detach가 사라져있다. Discriminator의 결과를 통해 Generator를 훈련시켜야하기 때문이다.

 

detach는 기존의 텐서를 참조하는 것이 아니라 새로운 텐서를 반환한다.

반응형

댓글