We propose a new framework for estimating generative models via an adversar-ial process, in which we simultaneously train two models: a generative model G that captures the data distribution, and a discriminative model D that estimates the probability that a sample came from the training data rather than G.
The training procedure for G is to maximize the probability of D making a mistake. This framework corresponds to a minimax two-player game.
In the space of arbitrary functions G and D, a unique solution exists, with G recovering the training data distribution and D equal to 1 everywhere.
In the case where G and D are defined by multilayer perceptrons, the entire system can be trained with backpropagation. There is no need for any Markov chains or unrolled approximate inference networks during either training or generation of samples. Experiments demonstrate the potential of the framework through qualitative and quantitative evaluation of the generated samples.
본 논문에서 적대적Adversarial 과정을 통해 생성 모델을 추정하기 위한 새로운 모델을 제안했다. 데이터 분포를 감지하는 생성모델G과 샘플이 생성모델이 아닌 확률 훈련 데이터에서 나올 확률을 추정하는 판별모델D 두 모델을 동시에 학습시켰다.
G에 대한 훈련 과정은 D가 잘못 판별할 확률을 최대화하는 것이다. 이 프레임워크는 2인용 미니맥스 게임에 해당한다.
미니맥스Minimax 게임
결정이론, 게임이론, 통계학, 철학에서 사용하는 개념으로 최악의 경우 발생가능한 손실(최대 손실)을 최소화 한다는 규칙이다. 손실이 아니라 이익이 기준이라면 최소 이익을 극대화한다는 의미에서 "maximin" 이라고 부르기도 한다. 원래 두 명의 참가자가 존재하는 제로섬zero-sum 게임 이론으로부터 시작하였으나 (두 참가자가 순차적으로 행동하는 경우와 동시에 행동하는 경우 모두 포함), 더 복잡한 게임과 불확실성이 존재할 때의 일반적인 의사결정에 이르기까지 널리 쓰이고 있다.
임의의 함수 G와 D의 공간안에서 G가 훈련 데이터의 분포를 재구축하고, D는 모든곳에서나 1인 고유한 해가 존재한다.
G와 D가 MLP(다층 퍼셉트론)에서 정의되어 있는 경우, 전체 시스템은 역전파Backpropagation을 통해 학습될 수 있다. 샘플들을 생성하거나 학습시킬때 Markov chain이나 unrolled approximate inference networks 가 필요하지 않는다. 실험은 생성된 샘플들의 정성/정량적 평가를 통해 프레임워크의 가능성을 보여준다.
VAE는 input image X를 잘 설명하는 feature를 추출하여 Latent vector z에 담고, 이 Latent vector z를 통해 X와 유사하지만 완전히 새로운 데이터를 생성하는 것을 목표로 한다.
잠재 벡터 ( latent vector )
정규 분포나 균등 분포에서 독립적으로 개별적인 잠재 변수를 뽑아서 잠재 벡터를 생성 (인코더를 거쳐 잠재 벡터를 생각한다.)
이때 각 feature가 가우시안 분포Gaussian distributioin(=정규분포)를 따른다고 가정하고 latent z는 feature의 평균과 분산값을 나타낸다.
- p(z) : latent vector z의 확률밀도함수. 가우스 분포를 따른다 가정
- p(x|z) : 주어진 z에서 특정 x가 나올 조건부 확률에 대한 확률밀도함수
- θ : 모델의 파라미터
Input image X를 Encoder에 통과시켜 Latent vector z를 구하고, Latent vector z를 다시 Decoder에 통과시켜 기존 input image X와 비슷하지만 새로운 이미지 X를 찾아내는 구조이다.
VAE는 input image가 들어오면 그 이미지에서의 다양한 특징들이 각각의 확률 변수가 되는 어떤 확률 분포를 만들게 된다. 이런 확률 분포를 잘 찾아내고, 확률값이 높은 부분을 이용하면 실제에 있을법한 이미지를 새롭게 만든다.
모델의 파라미터θ가 주어졌을 때 우리가 원하는 정답인 x가 나올 확률이 p_θ(x) 높을 수록 좋은 모델이라고 할 수 있다. 즉 pθ(X)를 최대화 하는 방향으로 VAE의 파라미터θ를 학습시키게 된다.
VAE(Variational AutoEncoder)는 기존의 AutoEncoder와 탄생 배경이 다르지만 구조가 상당히 비슷해서 Variational AE라는 이름이 붙은 것이다. 즉, VAE와 AE는 엄연히 다르다.
AutoEncoder의 목적은 Encoder에 있다. AE는 Encoder 학습을 위해 Decoder를 붙인 것이기 때문이다. 반대로 VAE의 목적은 Decoder에 있다. Decoder 학습을 위해 Encoder를 붙인 것이기 때문이다.
VAE는 단순히 입력값을 재구성하는 AE에서 발전한 구조로 추출된 잠재 코드의 값을 하나의 숫자로 나타내는 것이 아니라, 가우시안 확률 분포에 기반한 확률값으로 나타낸다.
AE는 잠재 코드(latent code) 값이 어떤 하나의 값이 나오고(첫번째 그림 전자), VAE는 잠재 코드값이 평균과 분산으로 표현되는 가우시안(정규) 분포로 나온다.(첫번째 그림 후자)
따라서 AE가 만들어낸 잠재 공간은 군집이 비교적 넓게 퍼져있고(a), 중심으로 잘 뭉쳐있지 않지만, VAE가 만들어낸 잠재 공간은 중심으로 더 잘 뭉쳐져 있는 것을 확인 할 수 있다(b). 따라서 원본 데이터를 재생하는데 AE에 비해서 VAE가 더 좋은 성능을 보인다는 것을 알 수 있다. 즉 VAE를 통해서 데이터의 특징을 파악하는게 더 유리하다.
Adversarial modeling framework(적대적 모델링 프레임워크)는 모델이 모두 MLP(다층 퍼셉트론)일때 가장 간단하게 적용할 수 있다.
데이터 x에 대한 생성자 분포 p_g를 학습하기 위해 입력잡음변수 p_z(z)를 사전에 정의하고 난 후, 데이터 공간data space에 대한 매핑을 G(z; θ_g)로 나타낸다.
x~p_{data} : 원본 데이터 분포에서 한개의 데이터(이미지) x를 랜덤으로 뽑는다.
z~p_z : 노이즈 분포에서 한개의 노이즈 z를 랜덤으로 뽑는다.
또한 단일 스칼라를 output으로 출력하는 두번째 MLP D(x; θ_d)를 정의한다. D(x)는 x가 p_g(생성자로부터 나온 데이터)가 아닌 데이터에서 나온 확률을 나타낸다. 훈련 예제와 G로부터 나온 샘플에게 정답 레이블을 할당할assign 확률을 최대화하기 위해 D를 훈련한다. log(1-D(G(z)))를 최소화 하도록 G를 훈련한다.
즉, D와 G는 value function V(G,D)를 이용하여 2인 미니맥스 게임을 하는 것이라 할 수 있다.
위조 지폐범이 더욱더 진짜 같은 위조지폐를 만드는 과정을 단계별로 설명하면 다음과 같다.
이렇게 한번 경찰이 해당 지폐들을 구분을 완료하게 되면 한 번의 epoch 가 끝난다. 첫 번째 iteration에서는 경찰이 위조지폐와 진짜 지폐를 잘 구분할지라도, iteration을 거듭할수록 위조 지폐범은 더욱 비슷하게 생긴 위조지폐를 만들게 되고 경찰도 점차 위조지폐를 더 잘 구분할 수 있게 될 것이다.
그러다가 어느 순간 너무나도 완벽한 위조지폐가 탄생한다면, 경찰(D)은 결국 해당 지폐를 구분하지 못하기 때문에 이게 진짜인지 가짜인지 찍기 시작할 것이다. 확률은 둘 중 하나일 테니 결국 50%로 가 될 것이고, 그 순간 학습이 끝나게 된다.
Generator 입장에서는 Discriminator 가 진짜 이미지를 잘 맞추는지 못 맞추는지에 대해서 관심이 없다. 그냥본인이 만든 이미지가 얼마나 Discriminator를 속일 수 있냐가 중요하다. Generator 목적은 D(G(z)) 가 1이 되도록 하는 것뿐이다.
Discriminator 모델은 Supervised Learning, Generator 모델은 Unsupervised Learning이다.
D는 어떠한 input 데이터가 들어갔을 때, 해당 input 값이 어떤 것인지 Classify 하고, G는 어떤 latent code(잠재 코드)를 가지고 training 데이터가 되도록 학습하는 과정을 말한다.
먼저 가장 이상적인 상황에서의 D(Discriminator)입장을 생각해 보자. D가 보는 sample x가 실제로 data distribution 으로부터 온 녀석이라면 D(x)=1 이므로 첫번째 term에서 log값이 사라지고, G(z)가 만들어낸 녀석이라면 D(G(z))=0 이므로 두 번째 term 역시 0으로 사라진다. 이 때가 D의 입장에서 V의 "최대값"을 얻을 수 있다.
- black : 원본 데이터 분포, green : 생성 모델 분포, blue : 판별 모델 분포
GAN(Generative Adversarial nets)는 동시에 discriminative distribution(파란점선, D)을 업데이트함으로써 data generating distribution(검은 점선) p_{data}의 샘플과 generative distribution(녹색 선, G) p_g의 샘플을 구별하도록 훈련된다. 아래의 x선과 z선은 각각 x와 z의 도메인을 나타내며, 위로 뻗은 화살표가 x=G(z)의 mapping을 보여준다.
(a) 학습이 이뤄지지 않았을 때는 생성자 분포(녹색 선)이 원본 분포(검은 점선)를 잘 학습하지 못하기 때문에, distriminator가 생성자 분포와 원본 분포를 잘 구별한다.
(b) 알고리즘의 내부 루프에서 D는 데이터로부터 원본(p_{data})샘플과 생성자(p_g)샘플 구별하도록 훈련되어 D*(x) = p_data(x)/((p_data(x) + p_g(x))로 수렴한다.
(c) G로 업데이트한 후에, D의 기울기는 G(z)가 데이터로 분류될 가능성이 더 높은 데이터로 분류될 가능성이 더 높은 영역으로 갈 수 있게끔 돕는다.
(d) 여러 단계의 훈련 후에, 만약 G와 D가 충분히 학습되었다면 생성데이터의 분포가 원본데이터의 분포와 같아진다. p_g가 p_{data}에 수렴하게 된다. D는 가짜이미지와 진짜 이미지를 구별할 수 없게 되기에 판별모델의 분포는 1/2로 수렴하게 된다.
논문에서 한 가지 실용적인 tip이 나오는데, 위에 value function에서 log(1−D(G(z))) 부분을 G 에 대해 minimize하는 대신 log(D(G(z)))를 maximize하도록 G를 학습시는게 더 빠르다. (나중에 저자가 밝히 듯이 이 부분은 전혀 이론적인 동기로부터 수정을 한 것이 아니라 순수하게 실용적인 측면에서 적용을 하게 된 것이다.)
이유도 아주 직관적인데 예를 들어 학습 초기를 생각해보면, G가 초기에는 아주 이상한 image들을 생성하기 때문에 D가 너무도 쉽게 이를 real image와 구별하게 되고 따라서 log(1−D(G(z))) 값이 매우 saturate하여 gradient를 계산해보면 아주 작은 값이 나오기 때문에 학습이 엄청 느리다.
하지만 문제를 G=argmaxGlog(D(G(z)))로 바꾸게 되면, 초기에 D 가 G로 나온 image를 잘 구별한다고 해도 위와 같은 문제가 생기지 않기 때문에 원래 문제와 같은 fixed point를 얻게 되면서도 stronger gradient를 줄 수 있는 상당히 괜찮은 해결방법이다.
GAN의 목적식에서 D와 G의 목표는 다음과 같다.
1) Generator 입장 : 생성자 분포가 원본 분포와 같아진다. → p_g가 p_{data}로 수렴한다.
2) Discriminator 입장 : 가짜 이미지와 진짜 이미지를 구별할 수 없게 된다. → D(G(z))가 1/2로 수렴한다.
Proposition 1
G가 fixed 된 상태에서, D의 optimal point가 p_data(x)/{p_data(x)+p_g(x)}로 수렴됨을 보여보도록 한다.
D*로 가는 과정은 도메인 z에서 x로 mapping되는 과정이다. g(z) -> x , p_z(z) -> p_g(x)
a*log(y) + b*log(1-y) 형태이므로 y로 미분해서 도함수가 0을 가질때가 optimum이고, 개형에서 알 수 있다시피 optimum이지점인 a/(a+b)에서global maximum을 가지게 된다.
Proposition2
global optimal에서 p_g 가 p_data에 수렴한다. (p_g=p_data)
global optimal이므로 D*(G(z)) = p_g(x) / { p_{data}(x) + p_g(x) } 그리고 D*(x) = p_{data}(x) / { p_{data}(x) + p_g(x) } 로 수렴하게 된다. (by propostion 1)
C(G)는 마지막 항으로 전개가 가능하며 global optimal point를 얻기 위해선 JSD가 0이 되어야 한다. JSD는 p_{data}=p_g일때 0이 되게 된다.
따라서 global optimal point에서 p_{data} = p_g 이다.
global optimal point에서 p_{data} = p_g 이고(by proposition2), D*(G(z)) = p_g(x) / { p_{data}(x) + p_g(x) } (by proposition1) 이다. 따라서 D*(G(z)) = p_g(x) / 2*p_g(x) = 1/2 이 되게 된다.
앞서 증명한 것은 생성자 판별자 각각에 대해서 global optimum point가 존재할 수 있다는 것에 대한 증명이었다. 학습이 되어서 global optimal에 잘 도달할 수 있는가?라는 질문은 다른 내용이다.
사실 GAN은 학습이 어려운 네트워크이다. 추후 다른 GAN 논문에서 학습의 안정성을 더할 수 있는 다양한 테크닉이 추가적으로 나온다.
ⓛ 학습 횟수만큼 반복할 수 있도록 해준다. (epoch 설정)
② Discriminator 학습
m개의 noise와 원본 데이터를 sampling 한 후, 기울기 값을 구한 뒤 경사를 타고 ascending 해서 식의 값을 maximize하는 형태로. 원본 이미지에 대해선 1의 값을 생성된 이미지에 대해선 0을 내보내도록 학습한다.
③ Generator 학습
m개의 noise를 smapling하여 m개의 생성된 이미지를 만든 후, 기울기값을 minmize하게끔 학습한다.
④ 기울기는 gradient-based 방법으로 학습된다. 본 논문에서는 momentum을 사용하였다.
본 논문에서 G, D 학습은 MNIST, TFD(Toronto Face Database), CIFAR-10 데이터넷으로 진행하였다.
p_g에 있는 test set data들의 확률은 G로 생성된 데이터들에 Gaussian Parzen window을 맞추는 방식으로 측정하며 측정된 확률은 p_g에 속한 log 확률로 변환한다. 이 때 Gaussian의 표준편차 σ는 검증 데이터셋의 교차 검증을 통해 값을 정합니다. 해당과정은 실제 확률을 다루기 힘든 generative 모델들을 평가할 때 사용하는 방식이다.
다른 생성모델과 비교하였을 때 좋은 성능을 낸다. (최대우도가 높을 수록 좋은 성능을 가진다)
이미지는 랜덤하게 만든것을 그림으로 넣은 것이다.
마지막 칼럼의 데이터는 왼쪽 데이터와 가장 가까운 nearest neigbor에 해당하는 학습 데이터이다. 확인해보면 만들어진 이미지와 엄밀한 차이가 있다. -> 정말 있을법한 이미지를 잘 만들어 낸다.
AE 계열의 다른 생성모델과 비교하였을 때 상대적으로 sharp한 이미지가 잘 나왔다.
https://ichi.pro/ko/vae-variational-auto-encoder-leul-sayonghan-saengseong-modelling-277371603749134
https://m.blog.naver.com/euleekwon/221557899873
https://jaejunyoo.blogspot.com/2017/01/generative-adversarial-nets-2.html
댓글 영역