상세 컨텐츠

본문 제목

[Advanced ML & DL Week5] Generative Adversarial Nets

심화 스터디/Advanced ML & DL paper review

by avril22 2022. 11. 10. 01:13

본문

작성자: 15기 우명진

 

1. Introduction

minimax two player game

generator model(위조범, 거짓 화폐를 생산함) vs. discriminative model(경찰, 위조 화폐를 구별해냄)

→ 두 모델은 서로 겅쟁을 통해 성능을 향상 시켜 나간다. (Advsersarial Model)

 

 

 

 

GAN 모델은 unsupervised learning으로, generator model은 노이즈로부터 실제 데이터와 유사한 가짜 이미지를 생성한다. 그리고 discriminator model은 fake image를 실제 데이터와 구별하여 판별하도록 학습한다. 

 

-Generative Model과 Discriminative Model은 모두 다층 퍼셉트론이어서 Markov Chain을 사용하지 않고, backpropagation과 dropout algorithm을 통한 학습이 가능하다. 

-Generative Model은 fake image를 들킬 확률이 최소화되어야 하며, Discriminative Model은 정확하게 판별할 확률이 높아야 한다. 

 


2. Related Work 

3. Adversarial Nets

 

목적함수를 통한 GAN 이해

G는 목적함수 V를 최소화하고, D는 목적함수 V를 최대화한다

 

 

> Generator 목적 관점

  • pz(z): noise z가 주어졌을 때
  • loss인 log(1-D(G(z))를 최소화
  • 이는 D(가짜 이미지)가 진짜(1)일 확률을 높이는 것과 같은 원리

→ 생성한 가짜 데이터가 들킬 확률을 최소화하는 것이 목표!

 

 

> Discriminator 목적 관점 

  • pdata(x): 실제 데이터가 주어졌을 때 
  • loss인 log(D(x))을 최대화
  • 이는 D(진짜 이미지)일 때 진짜(1)라고 판단할 확률을 높이는 것과 같은 원리

 

  • pz(z): noise z가 주어졌을 때 
  • loss인 log(1-D(G(z))를 최대화
  • 이는 D(가짜 이미지)가 가짜(0)이라고 판단할 확률을 높이는 것과 같은 원리 

→ 진짜 데이터와 가짜 데이터를 정확하게 판단할 확률을 최대화하는 것이 목표!

 


분포를 통한 GAN 학습 과정 이해 

 

-z: generator를 생성해내는 noise, z는 uniform distribution

-초록선 : generative distribution

-검은 점선: 실제 이미지의 분포 

-파란 점선: discriminative distribution

+) GAN가 학습할 때는 Generator와 Discriminator 중에서 하나를 고정하고 하나만 학습함 

 

 

(a) uniform distribution인 z에서 x(이미지 공간)으로 mapping이 발생함 → generator 학습 

  처음의 generator는 원래 이미지 분포와 달리 정확하지 않은 분포를 생성 ex) 원하지 않은 이미지가 출력됨

 

(b) generator가 고정되어 있는 상태에서, discriminator를 학습하였다. 

  학습한 discriminator는 (a)에 비해서 안정적으로 학습하는 것을 알 수 있고, 진짜 이미지에 대해서 1을 반환함.

 

(c) discriminator가 고정되어 있는 상태에서, generator를 학습하였다. 

  z에서 x로의 mapping이 (a)에 비해서 실제 이미지 분포에 가깝도록 학습되었다. 

 

(d) GAN 학습 과정을 계속 반복하여 나타나는 결과 

      1) z에서 x로의 mapping이 실제 이미지와 거의 동일한 분포가 생성이 된다. 

     2) discriminator는 0.5의 확률을 반환(실제 이미지와 가짜 이미지 판별할 확률이 반반) 

→ 이후 discriminator의 성능은 더 이상 높아지지 않는다

 


 

4. Theoretical Results

 

5. Experiments

  • MNIST data

 

 

 

 

 


 

GAN 코드 분석

 

> Generator, Discriminator Class 생성 

  • Generator
    • latent_dim: z의 벡터 차원
    • 100 -> 128 -> .. -> 1024로 이미지 차원에 가깝도록
    • 이미지를 fixel로 flat하면 길이가 1024
    • Tanh: 이미지의 값을 -1~1로 normalize해주기 때문에 이것을 맞추어주기 위하여 사용 
  • Discriminator
    • 1024의 값을 받은 후, 확률을 output으로 도출하는 Sigmoid함수 

 

 



> Loss Function, Optimizer

  • loss function: Binary Cross Entropy Loss(BCELoss)
  • G와 D에 대해서 optimizer를 따로 생성함 

 

 

> Training

  • valid는 실제 이미지인 경우(fill = 1)
  • fake는 가짜 이미지인 경우(fill = 0)

 

 

 

> Train Generator

  • Generator가 학습할 때는 Discriminator가 고정된다.
  • z의 사전 확률 분포는 가우시안 분포이고, random한 값 하나를 뽑게 된다
  • Generator에 넣게 되면, 1024라는 이미지의 값 나옴
  • discriminator는 1에 가깝게 학습을 하고 싶기 때문에 'valid' 값과 비교하게 된다. 즉 gen_imgs의 값이 1이 되도록 학습한다

 

 

 

 

> Train Discriminator

  • Discriminator가 학습할 때는 Generator가 고정된다.
  • Generator는 두가지 학습을 함.
    • 1) 실제 이미지가 valid(1)에 가깝도록 학습을 함.
    • 2) generator가 생성한 가짜 이미지가 fake(0)에 가깝도록 학습을 함. 이때 사용하는 데이터는 Discriminator가 생성한 가짜 데이터이다. 

 

 


6. Advantages & Disadvantages

장점

1) Markov Chain이 필요하지 않고 backpropagation을 이용해서만 gradient를 계산한다. (한번에 샘플을 생성 가능)

2) adversarial model은 generator의 파라미터가 실제 데이터에 의해 직접 업데이트 하는 것이 아니라 discriminator의 gradient에 의해 업데이트를 하기 때문에 통계적 이점을 얻는다

3) degenerate(변형된) distribution를 표현할 수 있다

 

 

단점 

1) Mode collapse: 분포의 형태를 전반적으로 맵핑하기보다, 단순히 오류를 최소화하기 위해서 최빈값에만 집중하여 학습함에 따라 실제 값(이미지 등) 중 특정한 형태만 생성

→ mini-batch discrimination, feature matching 등 도입하여 해결 

2) 어떻게 평가할 것인가? : 생성된 결과가 잘 생성된 것인지 판단할 수 있는 지표가 없고, 이에 따라 학습을 얼마나 진행해야 하는지 명확한 기준이 부족하다 

→ Inception Score(생성된 이미지의 다양성을 측정하는 지표) 사용하여 해결 

3) Unstable Training

서로 이기려는 minimax 게임을 통해 학습하기 때문에 G와 D 간의 힘의 균형이 깨지기 쉬움

→ DCGAN 등 학습을 안정적으로 바꾸고자 구조를 새로 제안하는 모델 등장 

 

 

 

 


 

참고한 자료들:

https://wegonnamakeit.tistory.com/54 

 

GAN(Generative Adversarial Networks) 논문 리뷰

01. Taxonomy of Machine Learning GAN 모델을 설명하기 전에 딥러닝을 크게 두 가지로 나누면, 1) Supervised Learning과 2) Unsupervised Learning이 있다. A. Supervised Learning 지도 학습 대표적인 모델로 Discriminative Model

wegonnamakeit.tistory.com

https://hyeongminlee.github.io/post/gan001_gan/ 

 

[GAN]Generative Adversarial Network | Hyeongmin Lee's Website

이번 포스트에서는 GAN의 기본 개념과 원리에 대해 알아보도록 하겠습니다. GAN(Generative Adversarial Network)은 Generator와 Discriminator의 경쟁적인 학습을 통해 Data의 Distribution을 추정하는 알고리즘입니

hyeongminlee.github.io

https://www.youtube.com/watch?v=jB1DxJMUlxY 

코드 링크:

https://github.com/eriklindernoren/PyTorch-GAN

 

관련글 더보기

댓글 영역