반응형
GAN에서의 미분 (Pytorch)
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
Pytorch에서 제공하는 DCGAN Tutorial에서 훈련스텝은 아래 코드와같다.
for epoch in range(num_epochs):
# For each batch in the dataloader
for i, data in enumerate(dataloader, 0):
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
## Train with all-real batch
netD.zero_grad()
# Format batch
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
# Forward pass real batch through D
output = netD(real_cpu).view(-1)
# Calculate loss on all-real batch
errD_real = criterion(output, label)
# Calculate gradients for D in backward pass
errD_real.backward()
D_x = output.mean().item()
## Train with all-fake batch
# Generate batch of latent vectors
noise = torch.randn(b_size, nz, 1, 1, device=device)
# Generate fake image batch with G
fake = netG(noise)
label.fill_(fake_label)
# Classify all fake batch with D
output = netD(fake.detach()).view(-1)
# Calculate D's loss on the all-fake batch
errD_fake = criterion(output, label)
# Calculate the gradients for this batch, accumulated (summed) with previous gradients
errD_fake.backward()
D_G_z1 = output.mean().item()
# Compute error of D as sum over the fake and the real batches
errD = errD_real + errD_fake
# Update D
optimizerD.step()
############################
# (2) Update G network: maximize log(D(G(z)))
###########################
netG.zero_grad()
label.fill_(real_label) # fake labels are real for generator cost
# Since we just updated D, perform another forward pass of all-fake batch through D
output = netD(fake).view(-1)
# Calculate G's loss based on this output
errG = criterion(output, label)
# Calculate gradients for G
errG.backward()
D_G_z2 = output.mean().item()
# Update G
optimizerG.step()
Tensorflow에서 제공하는 DCGAN Tutorial에서의 미분과정보다 코드가 더 복잡해보인다. Pytorch는 어떤 값을 자동미분해줄지 직접 작성해주어야하고, 한번 미분한 값은 computation이 없어지기 때문이다.
netD.zero_grad() # 판별자 Gradient 초기화
# 진짜이미지에대한 기울기구하기
real_img = data[0].to(device)
b_size = real_img.size(0) # 가독성을위해 real_cpu -> real_img로 변수명 변환
label = torch.full((b_size,), real_label, dtype=torch.float, device=device) # 라벨을 이미지의 사이즈만큼 만들어줌
output = netD(real_img).view(-1) # 판별자가 진짜이미지를 판단함
errD_real = criterion(output, label) # 1로 판단하는지의 손실을 구한다
errD_real.backward() # 진짜이미지에대한 미분
# 가짜이미지에대한 기울기구하기
fake = netG(noise) # 생성기에서 가짜이미지생성
label.fill_(fake_label) # 라벨링 0
output = netD(fake.detach()).view(-1) # 가짜이미지 detach() 후 판별
errD_fake = criterion(output, label) # 0으로 판단하는지의 손실을 구한다
errD_fake.backward() # 가짜이미지에대한 미분
optimizerD.step() # 최적화
Discriminator의 훈련과정만 뺀 코드이다. 중요한것은 Generator에 이미지를 넣어 생성된 출력이미지(fake)를 detach해준다. 현재는 판별자의 train step이고, 판별자는 라벨링된 이미지를 통해 훈련하는 것이 전부이기때문에 생성자의 연산까지 가져올 필요가 없다. 또한 가져온다면 다음에 사용하지 못한다. (retain_graph=True를 사용하지 않는 한)
netG.zero_grad()
label.fill_(real_label) # 생성자에게는 판별자가 진짜라고 판단하는 것이 라벨이다
output = netD(fake).view(-1) # 판별자의 판단에 따라서 생성자를 훈련시키는 것. detach가 없다.
errG = criterion(output, label) # 판별자의 결과를 통한 훈련
errG.backward() # 손실 미분
optimizerG.step() # 최적화
Generaotr 훈련코드를 보면, fake에서 detach가 사라져있다. Discriminator의 결과를 통해 Generator를 훈련시켜야하기 때문이다.
detach는 기존의 텐서를 참조하는 것이 아니라 새로운 텐서를 반환한다.
반응형
'인공지능 > CV' 카테고리의 다른 글
[논문리뷰] CoCa: Contrastive Captioners are Image-Text Foundation Models (0) | 2022.06.06 |
---|---|
인공지능이 만드는 폰트 [ HAN2HAN : Hangul Font Generation] (0) | 2021.11.13 |
YOLOv3를 이용한 턱스크찾기 프로젝트 (3) | 2021.05.01 |
객체탐지 (Object Detection) 2. YOLO !! (v1~v3) (2) | 2021.05.01 |
객체탐지 (Object Detection) 1. YOLO 이전 까지 흐름 (0) | 2021.05.01 |
댓글