이번 글에서는 2023년 AISTATS에서 발표된 Adversarial Random Forests for Density Estimation and Generative Modeling 논문에 대해 다뤄보도록 하겠습니다.

(랜덤포레스트 모델에 기반해서 데이터 생성할 수 있다는 것이 놀라웠고, 이 분야쪽은 아직 활발히 진행된 것 같지 않아서 깊게 파보면 좋은 모델을 만들어내는데 도움이 될 것 같은 느낌이 든다,,!)

Abstract

저자는 비지도 방식의 Random forest를 활용하여 데이터 생성하고, 밀도 추정에 대한 방법론을 제시하고자 합니다. GAN 모델에 영감을 받아서 생성자와 판별기가 반복적으로 학습하여 트리가 데이터의 구조적 특징을 학습하는 절차를 수행합니다.

이 방법은 최소한의 가정만으로 consistent를 보장한다고 강조하고 있으며, 저자의 방법론은 smooth한 밀도 추정을 제공하여 조금 더 자연스러운 데이터 생성이 가능하다고 합니다.

저자는 Tabular data를 생성하는 딥러닝 모델과 성능을 비교하고 있으며, 실행 속도가 평균적으로 훨씬 빠르다는 점을 강조하고 있습니다.

크게 알고리즘의 형태를 보자면

ARF → FORDE → FORGE 이 순서로 알고리즘이 진행됩니다.

ARF : 분류기 학습

FORDE : ARF에서 학습한 분류기를 통해 각 리프노드에서의 파라미터 추정

FORGE : 추정된 파라미터를 통해 데이터 생성

이제는 ARF 알고리즘에 대해 설명할 것입니다.(분류기 학습하는데 중점을 둔 알고리즘에 대해서!)

ADVERSARIAL RANDOM FORESTS

이제 저자가 제안한 방법론에 대해서 자세하게 다뤄보겠습니다.

ARF(Adversarial Random Forest)는 URF(Unsupervised Random Forest)의 재귀적 변형으로, 각 리프 노드에서 데이터가 서로 독립적인 상태(jointly independent)가 되도록 하는 것이 목표입니다. 아마도 변수별로 독립적인 패턴을 파악하여 간결하고 일반화된 모델을 만드는 데 도움이 될 것 같습니다.

(다음 글에서 ARF 알고리즘의 pseudo code를 직접 뜯어보면서 더 자세하게 다룰 것입니다. 아래는 알고리즘의 흐름)

초기 단계: 먼저 기존의 URF를 합성 데이터 $\tilde{X}^{(0)}$로 학습하여 분류기 $f^{(0)}$을 생성합니다. 만약 분류기의 정확도가 50%이상이라면(생성데이터와 원본데이터의 구분 확률) 리프노드의 coverage를 계산하게 됩니다. 이때 coverage는 원본데이터에 대해서 계산하게 됩니다.(coverage라는건 원본데이터가 각 리프노드에 들어간 비율을 말합니다.) 각 리프 노드의 커버리지를 계산한 후, 이 비율에 따라 리프 노드를 무작위로 선택하고, 그 안의 마진으로부터 새로운 합성 데이터셋 $\tilde{X}^{(1)}$을 생성합니다. (커버리지가 높을수록 그 리프에서의 데이터 생성 개수는 많아짐) 생성된 합성 데이터 $\tilde{X}^{(1)}$과 원본 데이터 $X$를 구분하는 새로운 분류기 $f^{(1)}$을 학습합니다.(ARF 알고리즘에서 분류기를 학습하고 나서 그 분류기를 통해 밀도추정(FORDE)과 데이터 생성(FORGE)하는 것임! ARF는 분류기 학습에 중점을 두는 것이지 데이터 생성하는것이 목적이 아닌것을 알아두자!)

새로운 분류기의 OOB(Out-of-Bag) 정확도가 충분히 낮아지면 ARF가 수렴한 것으로 간주하고, $f^{(1)}$를 최종 모델로 사용합니다. 그렇지 않다면, 위의 과정을 반복하여 새로운 합성 데이터셋을 생성하고 분류기를 학습합니다. → 수렴할때까지!

  • 수렴의 의미
    • 수렴한다는 것은 ARF의 분류기가 더 이상 합성 데이터와 원본 데이터를 구분할 수 없게 되었다는 것을 의미합니다. 즉, 합성 데이터가 원본 데이터와 매우 유사해져서 분류기의 정확도가 대략 50%에 도달한 상태를 말합니다. 이때를 수렴한다고 말하고 있습니다.
  • OOB(Out-of-Bag) 정확도란?
    • 랜덤 포레스트(Random Forest)에서 모델 성능을 평가할 때 사용하는 방법이며, 선택되지 않은 샘플을 활용하여 모델의 성능을 측정합니다. (부트스트랩 샘플링을 사용하면 데이터의 63%정도가 각 트리의 학습에 사용되고, 나머지 37%는 사용되지 않습니다. 각 트리에서 n개의 샘플을 복원추출로 뽑는다고 생각하면 한개의 샘플이 뽑히지 않을 확률은 (1-$\frac{1}{n}$)입니다. 트리가 무한대로 증가하면 $\lim\limits_{n\rightarrow \infin}(1-\frac{1}{n})^n$=$e^{-1}$ ≈ 0.3678794 )
    • 각 트리에 대해 OOB 샘플을 사용해 정확도를 측정하고, 모든 트리의 결과를 평균하여 최종 OOB 정확도를 계산합니다. → 충분히 낮다면 ARF수렴

ARF(Adversarial Random Forest)는 GAN과 유사한 구조를 가지고 있지만 몇 가지 중요한 차이점이 있습니다.

  • 유사점
    • ARF에서의 "생성자"는 marginal distribution에서 샘플링하여 데이터를 생성하는 역할을 하고, "판별자"는 랜덤 포레스트 분류기입니다. 이 둘은 서로 번갈아 가며 레이블 불확실성을 높이고 낮추는 제로섬 게임 형태로 작동합니다.
  • 차이점
    • 파라미터 공유: ARF에서는 생성자와 판별자가 동일한 파라미터를 공유합니다. GAN에서는 생성자와 판별자가 뉴럴네트워크를 통해 파라미터를 학습했지만, ARF에서 생성자는 스스로 학습하지 않고 판별자가 학습한 정보를 활용하는 방식입니다.
    • ARF모델에서 합성 데이터는 원본데이터들을 이용하여 복원추출한 데이터입니다. 즉, 새롭게 생성한 데이터가 아닌 것입니다! → ARF 알고리즘에서는 데이터를 독립적으로 만드는 leaf를 학습하는데 중점을 두고있습니다. 새로운 데이터를 생성하는 것은 FORDE를 통해 분포에 대한 parameter를 학습하고 FORGE에서 새로운 데이터를 생성합니다!
    • ARF는 한 번의 반복만으로도 데이터의 독립성을 유도할 수도 있다고 합니다.

 

위에서는 ARF 모델에 대한 흐름을 설명하였고 이제는 각 리프노드에서 변수들 간의 독립성을 만족시키기 위한 조건에 대해 살펴보겠습니다! (가장 중요한 부분! 변수간 독립이 만족하지 못하면 의미가 없음)

ARF(Adversarial Random Forest)가 데이터의 독립성을 확보하고 수렴하는 데 필요한 조건과 가정에 대해 설명하겠습니다.

Local Independence Criterion

  • 목표: ARF의 목표는 각 트리의 모든 리프 노드에서 데이터가 독립적으로 분포하도록 분할을 수행하는 것입니다. 모든 트리 b, 리프 노드 $\ell$, 그리고 샘플 x에 대해 다음을 만족하는 분할 세트 Θ를 찾는 것입니다. $p(x|\theta^\ell_b)=\prod^d_{j=1}p(x_j|\theta^\ell_b)$. 즉, 각 리프 노드에서 데이터의 joint probability가 각 변수들의 곱으로 표현될 수 있음을 의미하며 각 노드 내에서 변수들이 독립이라는 것을 말할 수 있습니다.

Assumption 1 : 특성 도메인 가정

  • 특성 공간 제한: 특성 공간이 $X=[0,1]^d$로 제한되어 있으며, 이 범위 내에서 joint density p가 0과 ∞ 사이에서 안정적으로 유지된다는 가정입니다. 이는 데이터가 모든 특성에 대해 정상적인 분포를 가지고 있고, 특성 값들이 정해진 범위 내에서 벗어나지 않음을 의미합니다.

Assumption 2 : Lipschitz 연속성 가정

  • Lipschitz 연속성: 각 라운드마다 목표 함수 P(Y=1∣x)가 Lipschitz 연속성을 만족해야 한다는 가정입니다.
    • Lipschitz 연속성은 함수의 변화 속도가 특정한 상수(즉, Lipschitz 상수)로 제한된다는 것을 의미합니다. Lipschitz 상수가 매 라운드마다 바뀔 수 있지만, 이 값이 $\frac{1}{max_{\ell,b}(diam(X^\ell_b))}$
    보다 빠르게 증가하지 않는다는 제한을 둡니다.
    • 이 가정은 데이터 분포가 과도하게 변화하지 않도록 하여 안정적인 학습을 보장합니다.

Assumption 3 : 트리 구성 및 학습에 관한 조건

이 가정은 ARF에서 사용되는 트리 구조에 대한 여러 가지 조건을 설정합니다.

  • (i) 트레이닝 데이터 분할: 각 트리를 학습할 때, 데이터를 두 부분으로 나눕니다:
    • 하나는 분할 기준을 학습하기 위한 부분이고,
    • 다른 하나는 리프 노드에 레이블을 할당하기 위한 부분입니다.
  • (ii) 트리의 성장 방식:
    • 각 트리는 부트스트랩 샘플링이 아닌 subsamples(중복 허용x)를 사용해 학습됩니다.
      • 부트스트랩 샘플링으로 학습되면 과적합이 발생할 수도 있기 때문!
      • 이라고 하는데, 실제 코드 뜯어보니 부트스트랩 샘플링으로 트리를 학습하는 것 같은데,,, 혹시나 보신 분이 있다면 의견 부탁드립니다. (_ _)
      • 트리 학습할 때 subsampling하려면 ranger함수에서 수정을 해야할 듯 함(뇌피셜)
    • subsamples의 크기 $n_b$(트리b에서 학습에 사용되는 데이터 수)는 n보다 작아야 합니다. ($n_b$→∞, $n_b$/**$n$**→0 as **$n$**→ ).
  • (iii) 분할 확률:
    • 각 내부 노드에서 특정 특성 $X_j$를 분할할 확률은 최소 π>0으로 제한됩니다.
      • 예를 들어 키와 몸무게에 대한 변수가 있다면, 처음 루트노드에서 키를 선택할지 몸무게를 선택할지에 대한 확률을 부여해줘야 합니다. 즉, 특정 특성에만 의존하지 않도록 합니다.
  • (iv) 분할의 균형성:
    • 각 분할에서 두 자식 노드에 들어가는 데이터 비율은 최소한 γ∈(0,0.5]이어야 합니다.
      • 이 부분은 자식노드에 최소한의 비율을 보장하기 위해서 설정한 것입니다. 만약 설정하지 않게되면 두개의 자식노드의 비율이 0.99, 0.01 이렇게 분할이 되는 상황이 발생할 수도 있기 때문입니다.
  • (v) 리프 노드의 개수:
    • 각 트리 b에 대해 리프 노드의 총 개수 $L_b$는 ∞로 향해야 하지만, 전체 데이터 수 n에 비해 작아야 합니다. ($L_b$→∞, $L_b$/**$n$**→0 as **$n$**→ ).
  • (vi) 소프트 레이블:
    • 각 노드의 예측 확률을 투표가 아닌 평균을 통해 결정한다는 의미입니다.

ARF는 특성 독립성을 확보하고 데이터를 효과적으로 분할하기 위해 위의 가정들(A1~A3)을 만족하는 환경에서 작동합니다. 이러한 가정들은 ARF가 수렴하고, 리프 노드에서 데이터의 독립성을 확보하여 밀도 추정 및 데이터 생성에 적합한 모델을 학습하는 데 필수적인 역할을 합니다.

  • 데이터의 변수들이 독립이 되는것을 아주 강조하는걸 보니 독립이냐 아니냐에 따라 모델의 성능을 좌지우지하는 것 같다는 생각이 드네요(어쩌면 독립이 아니라면 성능이 너무 낮게 나올수도)

 

Density Estimation and Data Synthesis

ARF에서 변수별 독립을 만족시키기 위한 가정들을 살펴보았고, ARF알고리즘을 통해 분류기를 생성하고 그 분류기를 활용하여 Density 추정과 데이터 생성하는 파트에 대해 설명하겠습니다.

ARF(Adversarial Random Forest)는 두 가지 알고리즘인 FORDE(FORests for Density Estimation)와 FORGE(FORests for GEnerative modeling)의 기반이 됩니다. [위에서 말했듯이 ARF알고리즘을 통해 분류기 생성 → 그 분류기를 이용해 density 추정(FORDE) → density를 통해 각 변수별 데이터 생성(FORGE)]

이 두 알고리즘에서 local independence criterion을 활용해 각 리프 노드 내에서 다변량 밀도 추정 대신 d개의 개별 단변량 밀도 추정기를 실행합니다. 이는 고차원 데이터에서 발생하는 차원의 저주를 피할 수 있기 때문에 훨씬 효율적입니다. 실제로 고차원 데이터에서 전통적인 커널 밀도 추정(KDE)은 차원의 제약으로 인해 잘 작동하지 않지만, ARF는 독립성을 확보하는 분할을 학습함으로써 밀도 추정과 데이터 생성에 더 효과적이고 효율적으로 대응할 수 있습니다.

  • 여기서 차원의 저주를 피할 수 있다고 말하는 이유는, ARF가 각 변수들 간의 독립성을 확보함으로써 각 변수에 대해 개별적으로 밀도 추정을 할 수 있기 때문입니다. 일반적으로 다변량 밀도 추정은 고차원 데이터에서 모든 특성 간의 결합 분포를 학습해야 하기 때문에, 차원의 저주(curse of dimensionality)라는 문제에 직면하게 됩니다. 이는 데이터의 차원이 높아질수록 추정해야 할 분포의 복잡성이 기하급수적으로 증가하기 때문입니다. 반면에, ARF의 경우 local independence criterion(각 변수들 간 독립!)에 따라 리프 노드 내에서 각 특성들이 독립적으로 분포하므로, 각 변수에 대해 단변량 밀도 추정(univariate density estimation)을 수행할 수 있습니다. 이를 통해 다변량 밀도 추정에서 발생하는 차원의 저주 문제를 효과적으로 피할 수 있게 됩니다.

ARF를 기반으로 FORDE와 FORGE 알고리즘이 다음과 같이 진행됩니다.

FORDE

  • 각 트리 b에 대해 분할 기준 $\theta^\ell_b$와 각 리프 노드 $\ell$의 경험적 커버리지 $q(\theta^\ell_b)$를 기록합니다. 이것들을 리프 노드의 파라미터라고 부릅니다.
  • 그런 다음 각 리프 노드에 대해, 원래 데이터의 각 특성 $X_j$에 대해 독립적으로 분포 파라미터 $\psi^\ell_{b,j}$를 추정합니다.
    • 예를 들어, 연속형 데이터의 경우 커널 밀도 추정(KDE)의 대역폭이나 MLE를 활용해 추정합니다.

각 변수에 대한 파라미터 $\psi^\ell_{b,j}$학습

  • 연속형 데이터의 경우 MLE를 사용하여 truncated Gaussian mixture model을 구현하고, 범주형 변수에 대해서는 베이지안 추론을 사용합니다. 이는 리프 노드에서 관찰되지 않은 값에 대해 극단적인 확률을 피하면서 리프 노드의 지원(support) 내에 있는 값을 반영하기 위함입니다.

FORGE

  • 트리 선택: 전체 트리 집합 B에서 트리 b를 균등하게 무작위로 선택하고, 해당 트리에서 리프 노드 $\ell$)를 커버리지 확률 $q(\theta^\ell_b)$ 에 따라 선택합니다. 이는 ARF 알고리즘의 재귀적 반복에서 합성 데이터를 생성하는 방식과 동일합니다.
  • 특성 샘플링: 각 특성 $X_j$에 대해, $\psi^\ell_{b,j}$에 의해 매개변수화된 밀도 또는 질량 함수에 따라 데이터를 샘플링합니다.

이렇게 데이터 생성을 했으니까 데이터 생성을 잘 했는지 확인할 필요가 있죠! 그러기 위해서 추정된 밀도와 실제 밀도를 비교합니다!

  • 추정된 밀도 함수 $q(x)$는 다음과 같이 표현됩니다 $q(x)=\frac{1}{B}\sum\limits_{\ell,b:x\in X^\ell_b}q(\theta^\ell_b)\prod\limits_{j=1}^dq(x_j;\psi^\ell_{b,j}).$
    • 여기서 B는 전체 트리의 수를 나타내고, 분포는 해당 리프 노드에 대한 커버리지 $q(\theta^\ell_b)$에 가중치가 부여된 모든 리프 노드의 평균값으로 나타납니다.
    • 이때 $q(x_j;\psi^\ell_{b,j})$는 각 특성 $x_j$에 대한 밀도 함수입니다.($\psi^\ell_{b,j}$를 distribution의 parameter로 설정)
  • **실제 밀도 함수 $p(x)$**는 다음과 같습니다
    • 이 경우 역시 커버리지 확률$p(\theta^\ell_b)$로 가중치가 부여된 각 리프 노드의 밀도로 구성되어 있습니다.
  • $p(x)=\frac{1}{B}\sum\limits_{\ell,b:x\in X^\ell_b}p(\theta^\ell_b)p(x|\theta^\ell_b).$

밀도 함수의 의미

  • 추정된 밀도 $q(x)$와 실제 밀도 $p(x)$ 모두, 각 리프 노드에서의 분포를 가중치로 합산한 것으로 표현될 수 있습니다.

이 부분에서는 ARF(Adversarial Random Forest) 알고리즘의 손실 함수, 추가 가정, 그리고 발생할 수 있는 세 가지 오류에 대해 설명하고 있습니다.

손실 함수 (Loss Function)

  • L2-consistency에 관심이 있기 때문에, 손실 함수는 평균 통합 제곱 오차(MISE, Mean Integrated Squared Error)로 정의됩니다. $MISE(p,q):=\mathbb{E}[\int_X(p(x)-q(x))^2dx].$
  • 이는 밀도 추정의 정확성을 측정하는 지표로, $p(x)$와 $q(x)$가 얼마나 가까운지를 평가합니다.

추가 가정

  • A4 : 실제 밀도 함수 p는 매끄럽다는 가정을 추가로 요구합니다.
    • 두 번째 도함수 $p^{''}$가 유한하고, 연속적이며, square integrable하고, monotone해야 합니다.
    • 이 가정은 커널 밀도 추정(KDE, Kernel Density Estimation)의 일관성을 보장하기 위한 표준 조건으로, ARF 분석에서 중요한 역할을 합니다.

ARF 방법론은 세 가지 잠재적인 오류를 가질 수 있습니다.

  1. Error of Coverage
    • 정의: 리프 노드 ‘와 트리 b에 대해, 실제 커버리지 $p(\theta^\ell_b)$와 추정된 커버리지 $q(\theta^\ell_b)$의 차이입니다.
    • 의미: 리프 노드의 커버리지가 실제 데이터와 얼마나 일치하는지를 나타내는 오류입니다.
  2. $\epsilon_1:=\epsilon_1(\ell,b):=p(\theta^\ell_b)-q(\theta^\ell_b)$
  3. Error of Density $\epsilon_2:=\epsilon_2(\ell,b,x):=\prod\limits_{j=1}^d p(x_j|\theta^\ell_b)-\prod\limits_{j=1}^dq(x_j;\psi^\ell_{b,j})$
    • 정의: 리프 노드 $\ell$, 트리 b, 샘플 x에 대해, 실제 조건부 밀도 $\prod\limits_{j=1}^d p(x_j|\theta^\ell_b)$와 추정된 조건부 밀도 $\prod\limits_{j=1}^dq(x_j;\psi^\ell_{b,j})$의 차이입니다.
    • 의미: 각 리프 노드 내에서 특성별로 독립적으로 추정된 밀도가 실제 밀도와 얼마나 다른지를 나타내는 오류입니다.
  4. Error of Convergence $\epsilon_3:=\epsilon_3(\ell,b,x):=p(x|\theta^\ell_b)-\prod\limits_{j=1}^d p(x_j|\theta^\ell_b)$
    • 정의: 리프 노드 $\ell$, 트리 b, 샘플 x에 대해, 실제 조건부 밀도 $p(x|\theta^\ell_b)$와 특성별 조건부 밀도의 곱 $\prod\limits_{j=1}^d p(x_j|\theta^\ell_b)$의 차이입니다.
    • 의미: 실제 조건부 밀도가 각 특성의 독립적 밀도의 곱과 얼마나 다른지를 나타내는 오류입니다. 이는 실제 데이터가 특성 간에 종속성을 가질 경우 발생할 수 있습니다.
  • $\epsilon_1$ : 리프 노드 $\ell$와 트리 b에 따라 달라지는 랜덤 변수입니다.
  • $\epsilon_{2,3}$ 리프 노드 $\ell$, 트리 b, 그리고 샘플 x에 따라 달라지는 랜덤 변수입니다.

이러한 가정들과 오류 정의를 통해 ARF는 consistency을 보장하고, 효과적인 데이터 생성 및 분포 학습을 수행할 수 있습니다.

 

 

Experiments

 

  • 위 그림은 ARF 모델을 사용하여 데이터 생성하였습니다. ARF모델이 생각보다 데이터 생성을 잘 하고 있는 것을 볼 수 있습니다.

  • 또한 ARF를 이용해 데이터 생성한 모델과 다른 딥러닝 모델의 성능을 비교해 보았습니다. 어떤 경우에는 FORGE가 성능이 높고 다른 데이터셋에선 다른 딥러닝의 모델이 성능이 높지만, ARF의 모델이 성능이 뒤쳐지지 않을뿐 아니라 가장 빠르게 학습을 할 수 있다는 것입니다. 시간측면에서는 월등히 앞서는 것을 볼 수 있습니다.

다음 글에서 ARF, FORDE, FORGE 알고리즘의 pseudo code를 깃허브에 제공되어있는 R코드와 함께 살펴보면서 각 단계별로 모델이 어떤 흐름으로 진행되는지 확인해보도록 하겠습니다.

 

감사합니다.

 

이번 글에서는 latent space에서 Minority class를 oversampling하여 불균형 문제를 해결할 수 있는 논문을 다뤄보겠습니다.

Abstract

이 논문은 데이터 불균형의 문제를 해결하기 위해 Minority class의 데이터를 latent space에서 Oversampling하여 불균형 문제를 해결하고자 합니다. 이때 Minority class의 데이터는 RAE(Regularized Auto Encoder)를 통해 latent space로 보내지게 됩니다.

latent space에서 Oversampling을 수행하기 때문에 좋은 latent space(데이터를 잘 표현해주는 잠재 공간)를 학습해야 합니다. 그러기 위해서 저자는 Auto-Encoder의 구조를 사용해 조건부 데이터 우도(conditional data likelihood)를 최대화함으로써 latent space를 효과적으로 학습합니다. (특히, latent 샘플들의 convex combination을 사용해 새로운 데이터를 생성하면서도 동일한 클래스의 identity를 보존해야 한다고 합니다. 즉, convex combination을 통해 생성된 데이터가 같은 클래스를 가져야 한다는 뜻 )

또한 저자는 SMOTE와 같은 naive한 oversampling방법과 비교하여 low variance risk estimate을 달성했다고 합니다. 결과적으로 이 방법은 Minority class를 oversampling하는데 효과적이며 불균형 데이터를 해결하는데 도움을 준다고 합니다.

이제부터 차근차근 하나씩 살펴보도록 하겠습니다.

Proposed Method

Class Preserving Oversampling

여기서는 latent space의 분포인 q(z)가 클래스 보존 방식으로 학습되는 것이 핵심입니다.

이때 latent space에서 convex combination을 통해 새로운 샘플을 생성하고자 합니다.

예를 들어 $z_i, i=1,2,...,t$ 이때 $z_i$는 latent vector라고 하고 그에 상응하는 데이터포인트 $x_i, i=1,2,...,t$가 모두 같은 클래스를 갖고있다고 하면 이때 새로운 latent point $z'$은 다음과 같이 얻어집니다.

$z'=\Sigma^t_{i=1}\alpha_i z_i$ $s.t. \ \ \Sigma_i \alpha_i=1$ and $\alpha_i \geq 0$

하지만, 단순히 이렇게 convex combination을 통해 생성된 새로운 데이터는 클래스 보존을 하지 못할수도 있습니다. (각 클래스의 latent space가 겹치는 영역이 존재하게 되면 convex combination을 통해서 생성된 데이터가 클래스 보존을 못할수도 있다는 얘기인 듯 합니다.)

 

그래서 클래스 보존에 대한 문제를 해결하기 위해서는 두가지를 말하고 있습니다.

  1. class-conditional latent space를 학습해야 합니다. 이때 각 클래스별로 class-conditional latent density가 겹치지 않아야 한다고 합니다. 즉, 각 클래스마다 서로 다른 영역을 가진 latent space를 가져야 합니다.
  2. 각 클래스들은 latent space에서 linearly separable해야 됩니다.

각 클래스 i에 속하는 샘플들의 집합 $R_i=\{x|h(z)=i\}$는, linear classifier h(z)에 의해 클래스 레이블 i를 갖는 샘플들입니다. 이때 $q(z|x\in R_i)$는 클래스 i에 대한 latent 분포를 의미합니다. 그리고 두 클래스 i, j에 대한 latent 분포의 support가 겹치지 않도록 학습해야 합니다.

즉, $Supp(q(z|x\in R_i)) \cap Supp(q(z|x\in R_j)) =\varnothing$식을 만족해야 합니다.latent space 학습과 linear classifier 학습을 함께 진행하여 oversampling된 벡터도 원래 클래스에 속하도록 만드는 방식을 다루고 있습니다.

위 두가지인 latent space 학습 linear classifier 학습을 함께 진행하여 oversampling된 벡터도 원래 클래스에 속하도록 만드는 방식을 다루고 있습니다.

여기서 linear classifier는 latent space에서 벡터들이 같은 클래스에 속하는지 확인하는 역할입니다. 예를 들어, 클래스 S1에 속하는 여러 샘플의 latent 벡터를 결합했을 때, 새로운 벡터도 클래스 S1로 분류되도록 linear classifier가 학습됩니다.(위 사진처럼 linear classifier가 linearly separable하다면 S1의 latent vector들의 convex combination도 S1의 Decision Boundary에 속할 것이라는 겁니다.)

위 사진을 보면 1. class-conditional latent space를 학습하였고(겹치지 않는 latent space), 2. latent space에서 linearly separable함!(저런 상태를 원함)

 

Regularized Autoencoders with class preserving latent space

이 부분은 클래스 정보를 유지하면서 오버샘플링을 수행하기 위한 latent space 학습 방법을 설명합니다. 저자는 degenerate representation을 방지하기 위해 클래스 보존 제약을 적용하여 **조건부 데이터 우도(conditional data likelihood)**를 최대화하면서 latent space를 학습합니다.

이를 위해 Regularized Auto-Encoder 구조를 활용하여 클래스 정보를 보존한 채 latent space를 학습합니다.

 

Degenerate Representation이란?

  • Degenerate representation이란, 학습된 모델이 입력 데이터를 지나치게 단순화하거나 압축하여 중복되거나 정보가 손실된 표현을 만들어내는 상황을 의미합니다. latent space에서 서로 다른 데이터들이 거의 동일한 벡터로 매핑되거나, 각 데이터의 고유한 특성들이 사라져 버리는 것을 말합니다.
  • 만약 단일 latent vector로 매핑이 된다면 (위에서 언급했던Class Preserving Oversampling) latent space가 클래스를 보존할 수는 있겠지만, 다양한 데이터를 생성할 수 없게되어 oversampling의 의미가 없어진다고 볼 수 있습니다.
  • 이 문제를 해결하기 위해서 $p(x|z')$(conditional data-likelihood)를 최대화 하여 latent space를 학습한다고 합니다. 이때 클래스 보존 제약을 추가하여 각 클래스 간의 latent space가 겹치지 않아야 합니다!

Encoder-Decoder 구조

  • Encoder: 입력 데이터(x)를 latent space로 축소합니다. 이때 조건부 잠재 분포(conditional latent distribution)를 학습하여, 입력 데이터가 어떤 클래스에 속하는지 반영한 채로 latent 벡터 z로 변환합니다. 이는 $q_\varphi(z|x)$로 표현됩니다. ( $\varphi$는 parameter)
  • Decoder: 학습된 latent 벡터 $z'$(이 $z'$은 latent vector들 간의 convex combination)를 기반으로, 원래 데이터를 복원하는 네트워크입니다. 이는 $p_\theta(x|z')$로 표현되며, oversampled된 데이터가 복원될 때도 클래스 정보가 잘 유지될 수 있도록 학습됩니다.
  • 조건부 데이터 우도(Conditional Data Likelihood)
    • Decoder 네트워크에서 Conditional Data Likelihood( $p(x|z')$ )를 최대화하여, latent space에서 학습된 데이터가 원본 데이터를 잘 복원할 수 있도록 합니다. 이 과정을 통해 latent space가 데이터를 잘 표현할 수 있는 구조로 학습됩니다.

 

여기서 클래스 보존 제약으로 linear classifier인 $h_w(z)$를 사용합니다. 이 linear classifier는 latent space에서 클래스 간 겹침이 없도록 학습합니다.

 

 

이제 최적화 문제를 해결하는데 클래스 보존을 유지하면서 latent space를 학습하고자 합니다!

$max_{\theta,\varphi,\omega}\ \ \mathbb{E}_{z'}logp_\theta(x|z')$

 

  • conditional data likelihood를 최대화 하여 latent space에서 얻어진 벡터 z가 degenerate representation을 피하면서 원본 데이터를 잘 표현하도록 합니다. 이때 제약조건으로 latent space에서 각 클래스간 겹침이 발생하지 않아야 합니다. ($Supp(q(z|x\in R_i)) \cap Supp(q(z|x\in R_j)) =\varnothing$)
  • conditional data likelihood를 최대화 하여 latent space에서 얻어진 벡터 z가 degenerate representation을 피하면서 원본 데이터를 잘 표현하도록 합니다. 이때 제약조건으로 latent space에서 각 클래스간 겹침이 발생하지 않아야 합니다. ($Supp(q(z|x\in R_i)) \cap Supp(q(z|x\in R_j)) =\varnothing$)

세가지 네트워크인 Encoder, Decoder, Classifier는 동시에 학습됩니다. 학습과정에서 convex combination을 통해 oversampling이 진행되게 됩니다. 학습이 완료되면 oversampling된 데이터의 클래스를 보존할 수 있게 되며 성능이 향상될 것입니다.

 

 

Implementation Details

제안된 모델은 Encoder network($E_\varphi$), Decoder network($D_\theta$), 그리고 latent space는 linear classifier($L_\omega$)로 구성되어있습니다. (이때 linear classifier는 선형적으로 클래스들을 분리할 수 있도록 제약 걸었음)

 

이제 저자가 제안한 모델의 아키텍처를 살펴보겠습니다.

Figure 1: Architecture of the proposed methodology. It is a regularized autoencoder, where the latent space is regularized using a linear classifier to facilitate distance metric free class preserving oversampling of the minority classes. The decoder network maximizes the conditional data likelihood to avoid degeneracy in the latent space.

Mixer Network

위 그림을 보시면 $M_\xi$가 존재하는데 이는 Mixer network라고 불리며 여러 latent vector $z_i$들을 입력으로 받아 oversampling을 위한 mixing coefficient $\alpha$를 생성합니다. Mixer network는 softmax 층을 통해 $\alpha$를 생성합니다.

Mixer network는 linear classifier가 분류하기 어려운 샘플을 생성하도록 학습되며 그러기 위해서는 cross-entropy loss를 최소화하도록 학습됩니다. 즉, 새롭게 생성된 $z'$이 linear classifier에 의해 올바르게 분류되지 못하도록 합니다.

Loss는 oversampled된 샘플 $z'$가 원래 $z_i$들과 다른 레이블을 가지도록 설정됩니다. 즉, linear classifier가 이 oversampled된 샘플을 정확하게 분류하지 못하게 하는 것을 목표로 학습됩니다.

아래는 mixer network의 loss입니다.

  • C : 전체 클래스의 수
  • $y_j^{(k)}$ : 클래스 j에 대한 one-hot 인코딩된 레이블입니다. 데이터가 클래스 j에 속하면 1, 아니면 0
  • $\hat{y}'(k)=L_\omega(z')^{(k)}$: linear classifier $L_{\omega}$가 oversampled된 latent vector $z'$에 대해 예측한 클래스 j의 확률입니다.

이 loss function의 목적은 mixer network가 만든 oversampled된 latent vector $z'$이 linear classifier $L_{\omega}$에 의해 잘못 분류되도록 유도하는 것입니다.

 

cross entropy loss를 최소화한다는 것에 직관적으로 이해가 되지 않을 수 있어 예를 들어 설명하겠습니다.

Class : 개, 고양이, 토끼 이렇게 3개의 클래스가 있다고 하겠습니다.

 

원래 클래스는 개라고 설정하겠습니다.(i=개) 그리고 개의 latent vector들의 convex combination으로 새롭게 생성된 $z'$은 개이지만, 임의로 고양이로 설정하겠습니다.(j=고양이)

그러면 분류기는 $z'$에 대해 당연히 개라고 분류할 확률이 높겠죠? (개의 latent vector들로 convex combination을 수행했으므로!)

 

분류기가 $z'$을 개라고 분류할 확률 : 0.7

분류기가 $z'$을 고양이라고 분류할 확률 : 0.1

분류기가 $z'$을 토끼라고 분류할 확률 : 0.2 라고 하겠습니다.

$L_{mixer}$ = -( 0 * log0.7 + 1 * log0.1 + 0 * log0.2) = -log0.1 ~ 2.302585

loss는 2.302585입니다. 이때 우리는 cross entropy가 최소화되도록 하는겁니다. 그러면 고양이로 분류할 확률이 0.1이었지만 이 확률을 높이면 cross entropy가 작아지겠죠?

 

이제는 업데이트가 되어서

분류기가 $z'$을 개라고 분류할 확률 : 0.6

분류기가 $z'$을 고양이라고 분류할 확률 : 0.2

분류기가 $z'$을 토끼라고 분류할 확률 : 0.2 라고 하겠습니다.

$L_{mixer}$ = -( 0 * log0.6 + 1 * log0.2 + 0 * log0.2) = -log0.1 ~ 1.609438

 

계속 업데이트가 되어서

분류기가 $z'$을 개라고 분류할 확률 : 0.45

분류기가 $z'$을 고양이라고 분류할 확률 : 0.45

분류기가 $z'$을 토끼라고 분류할 확률 : 0.1 라고 하겠습니다.

$L_{mixer}$ = -( 0 * log0.45 + 1 * log0.45 + 0 * log0.1) = -log0.45 ~ 0.7985077

 

loss는 0.7985가 되었습니다. 즉, cross entropy를 최소화 하려면 분류기가 고양이의 확률을 높이는 것입니다. 결국 $z'$에 대해서 분류기는 개와 고양이 중에 분류하기 어려운 샘플이 만들어 지겠죠? 그것이 이 Mixer network가 원하는 것입니다!(분류기가 분류하는데 어려움을 겪게되는 샘플을 생성하는 것)

 

$z'$은 강아지 label을 갖지만, 고양이와 강아지를 확실하게 분류하기 어려운 샘플을 만드는 것이지요!

 

이를 통해, Mixer 네트워크는 challenging samples를 생성하여 모델의 robustness를 향상시키고, latent space에서 클래스 보존과 다양성을 동시에 유지하게 됩니다. (Mixer network의 역할은 $\alpha$를 다양하게 설정하여 동일 클래스 내에서 다양한 샘플을 생성하는 것입니다. 논문에서 각 클래스 간의 latent space가 겹치지 않도록 설계되었기 때문에, $\alpha$를 통해 어려운 샘플을 만들어도 클래스는 변하지 않습니다. 따라서, Mixer network는 클래스 내에서의 데이터 다양성을 높이기 위한 과정으로 이해할 수 있을 것 같습니다!)

 

이제 Mixer network를 살펴보았으니 Encoder, Decoder, linear classifier부분의 loss를 살펴보겠습니다.

Decoder

Decoder부분에서는 $p_\theta(x|z')$를 최대화 합니다. 저자는 adversarial loss를 사용하여 decoder가 생성한 데이터의 분포와 기존의 데이터의 분포를 일치시키기 위함입니다. 여기서 WGAN-GP를 사용하여 안정적이고 수렴이 잘 되도록 합니다. 또한 Critic network를 도입하는데, 이는 생성된 샘플과 실제 데이터 샘플 간의 차이를 줄이고, 그래디언트 패널티를 통해 안정적인 학습을 보장합니다!

(아마 여기서 WGAN-GP를 사용한 이유는 만약 분포끼리의 겹침이 없으면 Kullback-Leibler Divergence & Jensen-Shannon Divergence는 안정적으로 학습이 어려워 wasserstein distance를 사용해 분포간 거리를 일치시키는 것 같습니다.)

  • Decoder 손실함수

  • 의미: 디코더는 Critic 네트워크의 출력을 최대화하여, 생성된 샘플이 실제 데이터와 유사한 분포를 갖도록 유도합니다.
  • Critic 손실함수

  • 의미: Critic은 생성된 샘플과 실제 데이터 샘플 간의 차이를 줄이고, 그래디언트 패널티를 통해 안정적인 학습을 보장합니다.
    • 디코더 단독으로는 생성된 샘플이 실제 데이터 분포와 얼마나 유사한지를 직접적으로 측정하기 어렵습니다. 단순한 재구성 손실(reconstruction loss)이나 픽셀 단위의 손실은 샘플의 질적 유사성을 제대로 반영하지 못할 수 있습니다.
    • Critic network는 생성된 샘플과 실제 샘플 간의 분포적 차이를 측정하여, Decoder가 보다 정교하게 학습할 수 있도록 피드백을 제공합니다.

그리고 여기서 sample-by-sample 간의 매치가 아니라 하는데, 그 이유는 생성된 데이터는 기존의 데이터에 존재하지 않기때문에 샘플들 간의 매치를 보는것이 아닌 생성된 데이터의 분포와 기존의 데이터의 분포 차이를 보는거라고 하고 있습니다.

 

Linear Classifier

다음은 linear Classifier에 대해서 설명하겠습니다.

여기서는 categorical crossentropy loss terms을 사용하게 되는데, 이때 사용되는 데이터는 (i=개의 개수와 새롭게 생성된 $z'$까지 총 t+1개의 데이터에 대해서 loss를 구하게 됩니다.)

아래는 classifier에 대한 loss function입니다.

여기서 $z'$ 외에는 앞에 텀에서 계산이 되고, 뒤에 텀에서 $z'$에 대한 loss가 계산이 되는데, 이번에는 분류기가 $z'$에 대해서 강아지로 분류할 수 있도록 업데이트 됩니다.(앞에서는 일부러 $z'$을 고양이로 설정해서 $\alpha$를 만들었지만, 얘는 결국엔 강아지에 대한 데이터입니다. 그렇기 때문에 여기 linear classifier에서는 $z'$에 대해서 강아지라고 제대로 분류할 수 있도록 손실함수를 업데이트 해야합니다!)

 

Encoder

Encoder network에서 input data를 latent space로 보낼때 중요한 정보를 담고잇도록 해야합니다.

아래는 encoder에 대한 loss function입니다.

 

encoder손실에는 분류 손실과, 평균 절대 오차 손실로 구성되어있습니다.

분류 손실을 학습하여 latent space에서의 vector들이 중요한 정보를 담을 수 있도록 합니다.

평균 절대 오차 손실에서는 input data랑 input data를 Encoder를 통해 잠재벡터로 변환 후 다시 디코더를 통해 복원한 데이터 간의 차이를 확인해 재구성이 잘 되도록 합니다. (또한 재구성이 잘 됐다는건, latent space에서 latent vector가 중요한 정보를 담고있었다고 해석할 수 있습니다.)

 

Experimental Results

제안한 모델이 다른 모델에 비해 성능이 얼마나 좋은지 확인해보도록 하겠습니다. 다양한 Vision data를 사용하고, 성능평가 지표로는 ACSA, F1 score, GM 을 사용하였습니다.

저자가 제안한 모델이 확실히 성능이 좋은 것을 알 수 있습니다!

 

 

또한 데이터의 질적 및 다양성을 측정하기 위한 지표로 Density와 Coverage를 사용했습니다. 이 지표들은 생성된 샘플이 실제 데이터 분포를 얼마나 잘 반영하고 있는지, 그리고 얼마나 다양한 샘플을 생성하고 있는지를 평가하는 데 유용합니다.

  • 높은 Density 값은 생성된 샘플이 실제 데이터의 특정 영역에 집중적으로 분포되어 있음을 의미하며, 이는 고품질의 샘플이 생성되고 있음을 시사합니다.
  • 높은 Coverage 값은 생성된 샘플이 실제 데이터의 다양한 특성을 잘 반영하고 있음을 의미하며, 이는 다양한 샘플이 생성되고 있음을 시사합니다.

확실히 제안된 모델이 전체적으로 다양한 데이터를 생성하고 있으며 실제 데이터 영역에 집중적으로 분포되어 있음을 알 수 있습니다!

 

 

그리고 또한 Ablation studies를 통해 어떤 process에서 성능이 효과적이었던지도 파악할 수 있습니다.

 

Mixer Network를 사용한 모델과 사용하지 않은 모델에 대해서는 그다지 큰 차이가 나타나지 않았네요. (되게 신박한 아이디어라고 생각했는데, 성능 측면에서는 엄청난 효과를 불러일으키진 않았네요)

 

 

 

마지막으로 위 사진은 latent space에서 각 클래스의 분포를 나타낸 것인데, MNIST의 데이터 경우에는 확연하게 분리가 되어져 있는 것을 볼 수 있습니다.(degenerate representation을 방지하기 위해 class-preserving regularization을 추가해 각 클래스들간 latent space에서 겹침이 없어야 한다는 조건을 어느정도 만족시킨 것 같습니다.) (다른 data의 경우에는 linearly separable하진 않게 나오긴 했네요. 완벽하게 linearly separable 하다면 엄청난 모델이 될 것 같습니다!)

Conclusion

Minority class의 데이터를 oversampling하는데 클래스 보존의 제약을 걸어주고 다양한 데이터를 생성할 수 있게 하였습니다. Imbalanced data의 상황에서 이 모델은 아주 강력한 무기가 될 것 같습니다. 이런 concept을 vision data 뿐만 아니라 Tabular data에 대해서도 적용할 수 있으면 좋은 모델이 생길 수 있지 않을까 생각해 봅니다. Tabular data에 한번 적용시킬 수 있는 Idea를 고려해봐도 좋은 작업이 될 것 같습니다.

감사합니다!

 

참고문헌

  1. 논문 : https://proceedings.mlr.press/v206/mondal23a.html
  2. [WGAN-GP] https://arxiv.org/pdf/1704.00028
 

오늘은 2019년 NeurIPS에서 발표된 CTGAN 논문을 리뷰해 보겠습니다.

 

Abstract

Tabular data란 Discrete(이산형) columns, Continuous(연속형) columns을 갖고있는 데이터입니다.

Continuous columns은 multiple modes(여러개의 봉우리)를 가지고 있으며 Discrete columns은 각 카테고리 수가 불균형(암 환자 : 5%, 정상 환자 : 95%)하게 되어있으면 Deep neural network 모델은 모델링 하는데 어려움을 겪습니다.


저자는 CTGAN이라는 모델을 제안했으며 이 모델은 위에서 제시한 문제점을 해결하기 위해 Conditional Generator를 사용한다고 합니다. 모델이 어떤 구조인지 살펴보도록 하겠습니다.

 

Challenges with GANs in Tabular Data Generation Task

Table $T$(Tabular data)에 대한 column 정의는 다음과 같습니다.

Continuous columns : $\{C_1,...,C_{N_C}\}$

Discrete columns : $\{D_1,...,D_{N_d}\}$

Total columns : 총 $N_C$+$N_d$ 즉, N개의 컬럼을 갖고 있다고 보면 됩니다.

 

각각의 컬럼은 Random variable(확률변수)로 생각하고 각 컬럼은 unknown joint distribution $\mathbb{P}(C_{1:N_c},D_{1:N_d})$라고 합니다. 즉 각 컬럼 간 독립이 아니라 컬럼끼리 관계를 갖고 있다고 해석 할 수 있을 것 같네요

One row $\bold{r}j=\{c{1,j},...,c_{N_c,j},d_{1,j},...,d_{N_d,j}\},j\in\{1,...,n\}$

 

Table $T$를 $T_{train}$과 $T_{test}$로 나누고, $T_{train}$으로 G를 학습한 후에 G를 사용해 각 행을 독립적으로 생성한 집합을 $T_{syn}$이라 합니다.

 

저자는 2가지 측면으로 Generator의 효율성을 판단한다고 합니다!

  1. Likelihood fitness : $T_{syn}$로 생성한 컬럼이 $T_{train}$와 같은 joint distribution 따르는지
  2. Machine learning efficacy : $T_{train}$로 모델을 학습하여 $T_{test}$ 평가한 성능과 $T_{syn}$로 모델을 학습하여 $T_{test}$ 평가한 성능이 비슷한지?

기존의 GAN 모델을 사용해 Tabular data를 생성하면 문제가 있다고 했는데, 어떤 문제가 있는지 확인해 보겠습니다. 총 5가지 도전과제가 있습니다.

  1. Mixed data types : 연속형, 이산형 columns을 동시에 생성하기 위해서는 Softmax, Tanh 함수를 적용해야 합니다. Softmax는 이산형 columns, Tanh는 연속형 columns를 처리하기 위한 함수입니다.

    GAN의 경우 Mixed data를 처리할 때 최적화 문제가 발생한다고 합니다! GAN을 학습시킬 때, 이산형 데이터와 연속형 데이터를 동시에 다루기 위해 손실 함수를 적절히 조합하고 최적화해야 합니다. 이산형 데이터를 위한 손실은 일반적으로 분류 문제에 사용되는 크로스 엔트로피 같은 함수가 적합하고, 연속형 데이터에는 평균 제곱 오차와 같은 함수가 적합합니다. 이 두 손실 함수를 적절히 조합하는 것은 도전적이라고 합니다.

  2. Non-Gaussian distributions : 이미지 데이터의 경우 pixel들은 Gaussian과 유사한 분포를 따르기 때문에 [-1,1]의 범위로 normalizing할 수 있지만, Continuous 컬럼의 값들은 Gaussian 분포를 따르지 않아서 Tanh로 normalizing 시키면 vanishing gradient problem문제가 발생합니다.

    왜 gradient problem문제가 발생하느냐? → 만약 3개의 mode를 가진다고 가정해봅시다. 각각의 mode는 $N(0,1)$, $N(5,1)$, $N(10,1)$의 분포를 갖는 데이터들이 존재한다고 하면 이 값들을 Tanh로 normalizing 시켜버리면 $N(0,1)$과 $N(10,1)$ 분포에 존재하는 데이터 들은 각각 -1과 1에 근사하는 값을 가집니다.

    근데 Tanh함수는 아래와 같은 그림을 나타냅니다. 여기서 -1과 1에 근사하는 값의 기울기는 거의 0이 되겠죠. 그래서 Backpropagation과정에서 기울기가 소실된다고 말하는 것 같습니다. (기울기 소실이 발생하면, 네트워크의 특정 부분에서 가중치가 업데이트 되지 않거나 매우 느리게 학습되어, 전체적인 학습 과정의 효율성과 효과가 크게 저하됩니다.)
    Tanh 함수
  3. Multimodal distributions : 여러개의 mode(봉우리)를 가지고 있어서 Kernel Density Estimation(KDE)로 mode를 추정합니다. 기존의 GAN은 이런 Multimodal distribution을 모델링하는데 어려움을 겪는다고 합니다.

    그러면 GAN 모델은 Multimodal distribution을 모델링하는데 왜 어려움을 겪을까요? 모드 간 균형: GAN은 경향적으로 분포의 주된 모드에 초점을 맞추고, 덜 대표적인 모드는 무시할 수 있습니다. 이로 인해 데이터의 다양성과 복잡성을 완전히 포착하지 못할 수 있습니다.

  4. Learning from sparse one-hot-encoded vectors : 새로운 샘플들을 생성하면 모델은 softmax를 사용하여 각 카테고리의 확률(e.g. [0.7,0.2,0.1])을 출력합니다. 하지만, 실제 데이터는 one-hot vector(e.g. [1,0,0])로 표현됩니다.


    이게 무슨 문제가 되냐? 실제 데이터는 원-핫 벡터로 매우 희소한(0이 많음) 반면, 생성된 데이터는 확률 분포로 인해 상대적으로 덜 희소합니다. 이러한 차이는 판별자(discriminator)가 실제 데이터와 생성된 데이터를 구별하는 데 사용될 수 있습니다.

    위의 예시의 경우 one-hot vector에서 1의 카테고리가 생성된 모델에서는 0.7의 확률을 가지니 ‘진짜’라고 판단을 해야하는데, 이런 특성을 보지 않고 그저 벡터의 희소성만 확인하여 생성된 데이터를 가짜라고 구별하게 됩니다. 이렇게 되면 [0.99,0.005,0.005]의 데이터여도 희소하지 않기 때문에 가짜라고 구별하게 돼서 GAN 학습 과정에서 문제가 발생할 수 있는 것 같습니다.

    D입장에서는 [0.99,0.005,0.005]도 가짜라 생각해 0을 출력하니 Maximize가 되지만, G입장에서 보면 E_{z\sim p(z)}[log(1-D(G(z)))]값이 E_{z\sim p(z)}[log(1-0)]=0이 되어 Minimize가 되지 않습니다. 즉, 제대로 학습이 안되는거죠.

  5. Highly imbalanced categorical columns : 이산형 컬럼에서 category의 빈도가 불균형 하여 mode collapse(모드 붕괴)가 일어납니다. 즉, 생성자가 데이터의 다양성을 반영하지 못하고 주로 주요 카테고리만을 반복적으로 생성하는 현상입니다. 이는 GAN이 다양한 데이터 패턴을 학습하고 재현하는 데 실패하게 만듭니다.

    그리고 minor category를 생성과정에서 누락해도 데이터의 분포의 변화는 거의 없어서 판별자가 이것을 감지하기는 어렵다고 합니다! 그렇게 되면 minor category 데이터는 거의 생성이 되지 않는 문제가 발생하여 전체 데이터의 다양성을 학습하지 못하고 major category의 데이터만 생성하게 되는 문제가 발생하는 것 같습니다.

위 5가지 문제를 해결하면 2가지 평가지표가 좋은 성능을 보일 것 같습니다!

 

 

CTGAN Model

저자는 위 문제를 다음과 같이 해결하였습니다.

Mode-specific Normalization : Non-Gaussian distributions, Multimodal distribution(2,3) 해결(Continuous columns)

Conditional Generator and Training-by-Sampling : Imbalanced(4,5) 문제 해결 (Discrete columns)

 

 

Mode-specific Normalization

Non-Gaussian, Multimodal 문제였던 것을 확인해 봅시다. $T$에서 Continuous 컬럼은 아래 그림처럼 여러개의 mode를 갖는다고 했습니다.

위 그림은 3개의 mode가 존재하니 3개의 sub distribution으로 나누고, 각각의 distribution에서의 평균 : $\eta_k$, Weight : $\mu_k$, standard deviation : $\phi_k$로 설정합니다.

$c_{i,j}$ : i 번째 컬럼에 해당하는 j번째 행 데이터로 각각 Gaussian mode에서 발생할 확률은 $\rho_k$입니다.

그림에선 $\rho_3$ 확률이 가장 높으므로($c_{i,j}$는 $\rho_3$의 분포에서 나왔을 것!) $\beta_{i,j}=[0,0,1]$로 표현하고 $\alpha_{i,j}$도 식에 대입합니다.

기존의 One row $\bold{r}_j$를 다시 재표현 하면 아래와 같습니다.

여기서 $d_{1,j}$ 이전의 부분들은 continuous columns이고, mode-specific Normalization을 통해 구할 수 있습니다.

$d_{1,j}$ 이후의 부분들은 discrete columns로 구성되었으며 one-hot encoding 되어있습니다.

 

 

Conditional Generator and Training-by-Sampling

기존의 GAN에서는 minor category를 고려하지 못하는 문제가 있었기 때문에 Conditional Generator를 도입합니다. Conditional Generator를 통해 특정 이산형 컬럼의 값에 따라 데이터 행을 생성할 수 있도록 하는 것입니다. CGAN과 유사한 방식이라고 생각하시면 됩니다! 그러면 이제 어떻게 Condition을 줄 것이냐? 바로 Training-by-Sampling 방법으로 Condition을 줄 것입니다! ( 이 방식으로 불균형 한 문제를 해결할 수 있습니다.)

 

아래는 Training-by-Sampling 방식입니다!

각각 단계를 자세히 보겠습니다.

  1. $N_d$ 개의 discrete columns 중에 랜덤으로 한개의 컬럼을 선택 $i^*$
  2. 선택된 Discrete columns에 대해 PMF(확률질량함수) 구함
  3. PMF를 따르는 확률 분포에 따라 하나의 카테고리를 선택한다. 이를 $k^$라고 합니다. 아래는 $k^$ 선택하는 과정입니다.(카테고리가 2개가 있다고 가정)
    1. 첫 번째 카테고리 빈도 : 100
    2. 두 번째 카테고리 빈도 : 10
      1. 첫번째 카테고리 확률 : 100/110 ~ 0.9
      2. 두번째 카테고리 확률 : 10/110 ~ 0.1
    3. 그대로 사용하는게 아니라 log 변환 수행
      1. log(100) ~ 4.61, log(10) ~ 2.3
      2. 첫번째 카테고리 확률 : 4.61/6.91 ~ 0.667
      3. 두번째 카테고리 확률 : 2.3/6.91 ~ 0.333
    4. log의 유무에 따라 뽑힐 확률이 달라지는게 보이시죠! log변환을 통해 minor category가 뽑힐 수 있는 확률을 늘려주었습니다!
  4. Conditional vector 생성
    만약 2개의 이산형 컬럼이 존재하고, 첫번째 컬럼에는 3개의 카테고리, 두번째 컬럼에는 2개의 카테고리가 있는데, 2번째 컬럼의 첫번째 카테고리가 선택되었다면 Conditional vector는 다음과 같습니다. Conditional vector : [0,0,0,1,0]

 

그리고 여기서 저희가 Generator loss를 추가해 주는데, conditional vector로 2번째 컬럼의 1번째 카테고리가 주어졌을 때 이 조건에 맞는 이산형 벡터를 생성해야하는데, 잘못된 벡터가 생성되었을 수도 있으니 그 손실을 감소시키기 위해 cross-entropy 손실을 추가합니다!

 

위 방법으로 기존의 문제였던 5(Highly imbalanced categorical columns)번 문제를 해결할 수 있었습니다.

 

그러면 이제 4번문제가 남아있는데, 기존의 GAN에서 Softmax 함수를 사용했다면, Gumbel-Softmax 함수를 사용해서 sparse한걸로 판별했던 문제를 해결할 수 있었습니다.

Gumbel-Softmax 내용은 Chat GPT를 사용하였습니다

 

Gumbel-Softmax는 각 범주에 대한 확률을 계산한 후, Gumbel 분포를 통해 샘플링하여 one-hot vector와 유사한 출력을 생성할 수 있게 합니다. 이를 통해 신경망은 연속적인 방식으로 역전파를 수행하면서도, 이산적인 범주형 데이터를 효과적으로 생성할 수 있습니다.

기존의 GAN의 도전과제들을 다 해결하였습니다. 이제 두 가지 평가지표로 성능이 확실한지 파악해보겠습니다.

 

Evaluation Metrics and Framework

두가지 평가지표

  1. Likelihood fitness metric
  2. Machine learning efficacy

 

Likelihood fitness metric

과정은 아래 사진과 같습니다.

 

Synthetic Data Generator: 이 생성기는 학습 데이터를 기반으로 합성 데이터를 생성합니다.

합성 데이터(Synthetic Data): Generator에 의해 생성된 데이터입니다. 이 데이터는 실제 데이터와 유사한 데이터 입니다

Likelihood $L_{syn}$: 합성 데이터의 likelihood를 계산하여, 이 데이터가 실제 분포를 얼마나 잘 따르는지 평가합니다.

Likelihood $L_{test}$: 테스트 데이터에 대한 likelihood를 계산하여, 합성데이터가 실제 데이터를 얼마나 잘 모델링하는지 평가합니다.

 

Machine learning efficacy

실제 데이터에서 효율 확인

이번에는 합성데이터를 이용해 Decision Tree, Linear SVM, MLP를 사용하여 학습한 후 Test data에 대해 예측을 수행하여 Accuracy와 F1, $R^2$ 확인!

 

Result

TVAE와 CTGAN에서 우수한 성능을 보이고 있음!

그리고 CTGAN에서는 Generator에서 input data가 아닌 noise를 사용하기 때문에 Privacy 문제에 유용하다고 합니다(TVAE보다)

 

GM Sim, BN Sim에서 Likelihood 값이 커야 두 분포가 유사하다고 볼 수 있습니다. CTGAN이나 TVAE가 역시 다른 모델에 비해 likelihood값이 대략적으로 큰 것을 볼 수 있습니다!

 

 

참고문헌

  1. https://arxiv.org/pdf/1907.00503

 

감사합니다!

저번에 InforGAN에 대해 논문 리뷰를 해보았는데, 오늘은 MNIST 데이터셋에 대해서 InfoGAN에 대해 적용시켜 실제로 특징을 잘 학습하는지 확인해보도록 하겠습니다. (구글 코랩 기준으로 작성)

 

MNIST는 아래와 같은 아키텍쳐로 코드를 구현하셨습니다.

 

Import

from tqdm import tqdm
import time
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dset
import matplotlib.pyplot as plt
import torchvision.utils as vutils
import time
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset,DataLoader

# Dictionary storing network parameters. 설정한 파라미터
params = {
    'batch_size': 128,# Batch size.
    'num_epochs': 30,# Number of epochs to train for.
    'learning_rate': 2e-4,# Learning rate.
    'beta1': 0.5,
    'beta2': 0.999,
    'save_epoch' : 25,# After how many epochs to save checkpoints and generate test output.
    'dataset' : 'MNIST'}# Dataset to use. Choose from {MNIST, SVHN, CelebA, FashionMNIST}. CASE MUST MATCH EXACTLY!!!!!

 

데이터셋 불러오기

import torch
import torchvision.transforms as transforms
import torchvision.datasets as dsets

# Directory containing the data.
root = 'data/'

def get_data(dataset, batch_size):

    # Get MNIST dataset.
    if dataset == 'MNIST':
        transform = transforms.Compose([
            transforms.Resize(28),
            transforms.CenterCrop(28),
            transforms.ToTensor()])

        dataset = dsets.MNIST(root+'mnist/', train='train',
                                download=True, transform=transform)


    # Get FashionMNIST dataset.
    elif dataset == 'FashionMNIST':
        transform = transforms.Compose([
            transforms.Resize(28),
            transforms.CenterCrop(28),
            transforms.ToTensor()])

        dataset = dsets.FashionMNIST(root+'fashionmnist/', train='train',
                                download=True, transform=transform)

    # Get CelebA dataset.
    # MUST ALREADY BE DOWNLOADED IN THE APPROPRIATE DIRECTOR DEFINED BY ROOT PATH!

    # Create dataloader.
    dataloader = torch.utils.data.DataLoader(dataset,
                                            batch_size=batch_size,
                                            shuffle=True)

    return dataloader

Model

import torch
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.tconv1 = nn.ConvTranspose2d(74, 1024, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(1024)

        self.tconv2 = nn.ConvTranspose2d(1024, 128, 7, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)

        self.tconv3 = nn.ConvTranspose2d(128, 64, 4, 2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(64)

        self.tconv4 = nn.ConvTranspose2d(64, 1, 4, 2, padding=1, bias=False)

    def forward(self, x):
        x = F.relu(self.bn1(self.tconv1(x)))
        x = F.relu(self.bn2(self.tconv2(x)))
        x = F.relu(self.bn3(self.tconv3(x)))

        img = torch.sigmoid(self.tconv4(x))

        return img

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 64, 4, 2, 1)

        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)

        self.conv3 = nn.Conv2d(128, 1024, 7, bias=False)
        self.bn3 = nn.BatchNorm2d(1024)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), 0.1, inplace=True)
        x = F.leaky_relu(self.bn2(self.conv2(x)), 0.1, inplace=True)
        x = F.leaky_relu(self.bn3(self.conv3(x)), 0.1, inplace=True)

        return x

class DHead(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = nn.Conv2d(1024, 1, 1)

    def forward(self, x):
        output = torch.sigmoid(self.conv(x))

        return output

class QHead(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1024, 128, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(128)

        self.conv_disc = nn.Conv2d(128, 10, 1)
        self.conv_mu = nn.Conv2d(128, 2, 1)
        self.conv_var = nn.Conv2d(128, 2, 1)

    def forward(self, x):
        x = F.leaky_relu(self.bn1(self.conv1(x)), 0.1, inplace=True)

        disc_logits = self.conv_disc(x).squeeze()

        mu = self.conv_mu(x).squeeze()
        var = torch.exp(self.conv_var(x).squeeze())

        return disc_logits, mu, var

가중치, 노이즈, loss function

import torch
import torch.nn as nn
import numpy as np

def weights_init(m):
    """
    Initialise weights of the model.
    """
    if(type(m) == nn.ConvTranspose2d or type(m) == nn.Conv2d):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif(type(m) == nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

class NormalNLLLoss:
    """
    Calculate the negative log likelihood
    of normal distribution.
    This needs to be minimised.

    Treating Q(cj | x) as a factored Gaussian.
    """
    def __call__(self, x, mu, var):

        logli = -0.5 * (var.mul(2 * np.pi) + 1e-6).log() - (x - mu).pow(2).div(var.mul(2.0) + 1e-6)
        nll = -(logli.sum(1).mean())

        return nll

def noise_sample(n_dis_c, dis_c_dim, n_con_c, n_z, batch_size, device):
    """
    Sample random noise vector for training.

    INPUT
    --------
    n_dis_c : Number of discrete latent code.
    dis_c_dim : Dimension of discrete latent code.
    n_con_c : Number of continuous latent code.
    n_z : Dimension of iicompressible noise.
    batch_size : Batch Size
    device : GPU/CPU
    """

    z = torch.randn(batch_size, n_z, 1, 1, device=device)

    idx = np.zeros((n_dis_c, batch_size))
    if(n_dis_c != 0):
        dis_c = torch.zeros(batch_size, n_dis_c, dis_c_dim, device=device)

        for i in range(n_dis_c):
            idx[i] = np.random.randint(dis_c_dim, size=batch_size)
            dis_c[torch.arange(0, batch_size), i, idx[i]] = 1.0

        dis_c = dis_c.view(batch_size, -1, 1, 1)

    if(n_con_c != 0):
        # Random uniform between -1 and 1.
        con_c = (torch.rand(batch_size, n_con_c, 1, 1, device=device) * 2 - 1)

    noise = z
    if(n_dis_c != 0):
        noise = torch.cat((z, dis_c), dim=1)
    if(n_con_c != 0):
        noise = torch.cat((noise, con_c), dim=1)

    return noise, idx

MNIST의 경우에 digit type을 나타내는 변수 10개와 rotaion, width를 나타내는 변수 2개에 62개의노이즈를 추가해서 총 72개의 z로 시작합니다.

Train

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import time
import random

# Set random seed for reproducibility.
seed = 20240409
random.seed(seed)
torch.manual_seed(seed)
print("Random Seed: ", seed)

# Use GPU if available.
device = torch.device("cuda:0" if(torch.cuda.is_available()) else "cpu")
print(device, " will be used.\n")

dataloader = get_data(params['dataset'], params['batch_size'])

# Set appropriate hyperparameters depending on the dataset used.
# The values given in the InfoGAN paper are used.
# num_z : dimension of incompressible noise.
# num_dis_c : number of discrete latent code used.
# dis_c_dim : dimension of discrete latent code.
# num_con_c : number of continuous latent code used.
# num_z 62 -> 61, num_con_c 2 -> 3 
if(params['dataset'] == 'MNIST'):
    params['num_z'] = 62
    params['num_dis_c'] = 1
    params['dis_c_dim'] = 10
    params['num_con_c'] = 2
elif(params['dataset'] == 'FashionMNIST'):
    params['num_z'] = 62
    params['num_dis_c'] = 1
    params['dis_c_dim'] = 10
    params['num_con_c'] = 2

# Plot the training images.
sample_batch = next(iter(dataloader))
plt.figure(figsize=(10, 10))
plt.axis("off")
plt.imshow(np.transpose(vutils.make_grid(
    sample_batch[0].to(device)[ : 100], nrow=10, padding=2, normalize=True).cpu(), (1, 2, 0)))
plt.savefig('Training Images {}'.format(params['dataset']))
plt.close('all')

# Initialise the network.
netG = Generator().to(device)
netG.apply(weights_init)
print(netG)

discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

netD = DHead().to(device)
netD.apply(weights_init)
print(netD)

netQ = QHead().to(device)
netQ.apply(weights_init)
print(netQ)

# Loss for discrimination between real and fake images.
criterionD = nn.BCELoss()
# Loss for discrete latent code.
criterionQ_dis = nn.CrossEntropyLoss()
# Loss for continuous latent code.
criterionQ_con = NormalNLLLoss()

# Adam optimiser is used.
optimD = optim.Adam([{'params': discriminator.parameters()}, {'params': netD.parameters()}], lr=params['learning_rate'], betas=(params['beta1'], params['beta2']))
optimG = optim.Adam([{'params': netG.parameters()}, {'params': netQ.parameters()}], lr=params['learning_rate'], betas=(params['beta1'], params['beta2']))

# Fixed Noise
z = torch.randn(100, params['num_z'], 1, 1, device=device)
fixed_noise = z
if(params['num_dis_c'] != 0):
    idx = np.arange(params['dis_c_dim']).repeat(10)
    dis_c = torch.zeros(100, params['num_dis_c'], params['dis_c_dim'], device=device)
    for i in range(params['num_dis_c']):
        dis_c[torch.arange(0, 100), i, idx] = 1.0

    dis_c = dis_c.view(100, -1, 1, 1)

    fixed_noise = torch.cat((fixed_noise, dis_c), dim=1)

if(params['num_con_c'] != 0):
		# 회전, 너비 등을 더 자세히 보기위함
    con_c = (torch.rand(100, params['num_con_c'], 1, 1, device=device) * 2 - 1)
    fixed_noise = torch.cat((fixed_noise, con_c), dim=1)

real_label = 1
fake_label = 0

# List variables to store results pf training.
img_list = []
G_losses = []
D_losses = []

print("-"*25)
print("Starting Training Loop...\n")
print('Epochs: %d\nDataset: {}\nBatch Size: %d\nLength of Data Loader: %d'.format(params['dataset']) % (params['num_epochs'], params['batch_size'], len(dataloader)))
print("-"*25)

start_time = time.time()
iters = 0

for epoch in range(params['num_epochs']):
    epoch_start_time = time.time()

    for i, (data, _) in tqdm(enumerate(dataloader, 0)):
        # Get batch size
        b_size = data.size(0)
        # Transfer data tensor to GPU/CPU (device)
        real_data = data.to(device)

        # Updating discriminator and DHead
        optimD.zero_grad()
        # Real data
        label = torch.full((b_size, ), real_label, device=device)
        # label type을 맞추기 위해 추가
        label=label.to(torch.float32) 
        output1 = discriminator(real_data)
        probs_real = netD(output1).view(-1)
        loss_real = criterionD(probs_real, label)
        # Calculate gradients.
        loss_real.backward()

        # Fake data
        label.fill_(fake_label)
        noise, idx = noise_sample(params['num_dis_c'], params['dis_c_dim'], params['num_con_c'], params['num_z'], b_size, device)
        fake_data = netG(noise)
        output2 = discriminator(fake_data.detach())
        probs_fake = netD(output2).view(-1)
        loss_fake = criterionD(probs_fake, label)
        # Calculate gradients.
        loss_fake.backward()

        # Net Loss for the discriminator
        D_loss = loss_real + loss_fake
        # Update parameters
        optimD.step()

        # Updating Generator and QHead
        optimG.zero_grad()

        # Fake data treated as real.
        output = discriminator(fake_data)
        label.fill_(real_label)
        probs_fake = netD(output).view(-1)
        gen_loss = criterionD(probs_fake, label)

        q_logits, q_mu, q_var = netQ(output)
        target = torch.LongTensor(idx).to(device)
        # Calculating loss for discrete latent code.
        dis_loss = 0
        for j in range(params['num_dis_c']):
            dis_loss += criterionQ_dis(q_logits[:, j*10 : j*10 + 10], target[j])

        # Calculating loss for continuous latent code.
        con_loss = 0
        if (params['num_con_c'] != 0):
            con_loss = criterionQ_con(noise[:, params['num_z']+ params['num_dis_c']*params['dis_c_dim'] : ].view(-1, params['num_con_c']), q_mu, q_var)*0.1

        # Net loss for generator.
        G_loss = gen_loss + dis_loss + con_loss
        # Calculate gradients.
        G_loss.backward()
        # Update parameters.
        optimG.step()

        # Check progress of training.
        if i != 0 and i%100 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
                  % (epoch+1, params['num_epochs'], i, len(dataloader),
                    D_loss.item(), G_loss.item()))

        # Save the losses for plotting.
        G_losses.append(G_loss.item())
        D_losses.append(D_loss.item())

        iters += 1

    epoch_time = time.time() - epoch_start_time
    print("Time taken for Epoch %d: %.2fs" %(epoch + 1, epoch_time))
    # Generate image after each epoch to check performance of the generator. Used for creating animated gif later.
    with torch.no_grad():
        gen_data = netG(fixed_noise).detach().cpu()
    img_list.append(vutils.make_grid(gen_data, nrow=10, padding=2, normalize=True))

    # Generate image to check performance of generator.
    if((epoch+1) == 1 or (epoch+1) == params['num_epochs']/2) or epoch%5==0:
        with torch.no_grad():
            gen_data = netG(fixed_noise).detach().cpu()
        plt.figure(figsize=(10, 10))
        plt.axis("off")
        plt.imshow(np.transpose(vutils.make_grid(gen_data, nrow=10, padding=2, normalize=True), (1,2,0)))
        plt.savefig("Epoch_%d {}".format(params['dataset']) %(epoch+1))
        plt.close('all')


    # Save network weights.
    if (epoch+1) % params['save_epoch'] == 0:
        torch.save({
            'netG' : netG.state_dict(),
            'discriminator' : discriminator.state_dict(),
            'netD' : netD.state_dict(),
            'netQ' : netQ.state_dict(),
            'optimD' : optimD.state_dict(),
            'optimG' : optimG.state_dict(),
            'params' : params
            }, 'InfoGAN/model_epoch_%d_{}'.format(params['dataset']) %(epoch+1))


training_time = time.time() - start_time
print("-"*50)
print('Training finished!\nTotal Time for Training: %.2fm' %(training_time / 60))
print("-"*50)

# Generate image to check performance of trained generator.
with torch.no_grad():
    gen_data = netG(fixed_noise).detach().cpu()
plt.figure(figsize=(10, 10))
plt.axis("off")
plt.imshow(np.transpose(vutils.make_grid(gen_data, nrow=10, padding=2, normalize=True), (1,2,0)))
plt.savefig("Epoch_%d_{}".format(params['dataset']) %(params['num_epochs']))


# Save network weights.
torch.save({
    'netG' : netG.state_dict(),
    'discriminator' : discriminator.state_dict(),
    'netD' : netD.state_dict(),
    'netQ' : netQ.state_dict(),
    'optimD' : optimD.state_dict(),
    'optimG' : optimG.state_dict(),
    'params' : params
    }, 'InfoGAN/model_final_{}'.format(params['dataset']))


# Plot the training losses.
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("Loss Curve {}".format(params['dataset']))

# Animation showing the improvements of the generator.
fig = plt.figure(figsize=(10,10))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
anim = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
anim.save('infoGAN_{}.gif'.format(params['dataset']), dpi=80, writer='imagemagick')
plt.show()

이렇게 돌리면 오류가 나는 부분이 있는데, torch.save쪽에서 오류가 납니다. 저장할때 MNIST data에 대해서 수행했다면 ‘InfoGAN/model_final_MNIST’ 에 저장이 됩니다. 즉, InfoGAN 파일에 model_final_MNIST로 저장이 되는데 저희는 코랩에서 아무것도 건드리지 않았기 때문에 InfoGAN 파일이 없죠. 그래서 직접 만들어야 합니다. 만들면 이 오류는 없어지게 됩니다!

InfoGAN 파일을 위 처럼 만드셨다면 문제없이 실행됩니다.

 

또한 label=label.to(torch.float32) 부분은 label의 type이 torch.long 형태에서 ‘loss_real = criterionD(probs_real, label)’ 이부분에서 오류가 납니다.(probs_real은 float형태기 때문에)
따라서 probs_real과 type을 동일하게 하기 위해 float으로 변경하였습니다.

 

마지막으로 if((epoch+1) == 1) or epoch%5==0: 이 부분은 epoch이 5의 배수만큼 돌았을 때 사진을 출력하도록 변경하였습니다.

이렇게 설정하고 나서 분석을 수행한 결과를 보여드리겠습니다.

Epochs 100번 iteration

아무런 정보도 없었는데, 그래도 잘 분류하네요!

 

이제 숫자 말고 $c_2$, $c_3$(Rotation,Width)을 uniform 분포에서 변경할수록 어떻게 변화하는지 살펴보도록 하겠습니다.

 

Feature

import argparse

import torch
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt


parser = argparse.ArgumentParser()
parser.add_argument('-load_path', required=True, help='Checkpoint to load path from')
args = parser.parse_args(['-load_path', 'InfoGAN/model_final_MNIST'])

# from models.mnist_model import Generator

# Load the checkpoint file
state_dict = torch.load(args.load_path)

# Set the device to run on: GPU or CPU.
device = torch.device("cuda:0" if(torch.cuda.is_available()) else "cpu")
# Get the 'params' dictionary from the loaded state_dict.
params = state_dict['params']

# Create the generator network.
netG = Generator().to(device)
# Load the trained generator weights.
netG.load_state_dict(state_dict['netG'])
print(netG)

c = np.linspace(-2, 2, 10).reshape(1, -1)
c = np.repeat(c, 10, 0).reshape(-1, 1)
c = torch.from_numpy(c).float().to(device)
c = c.view(-1, 1, 1, 1)

zeros = torch.zeros(100, 1, 1, 1, device=device)

# Continuous latent code.
c2 = torch.cat((c, zeros), dim=1)
c3 = torch.cat((zeros, c), dim=1)
# c4 = torch.cat((zeros, c), dim=1)

idx = np.arange(10).repeat(10)
dis_c = torch.zeros(100, 10, 1, 1, device=device)
dis_c[torch.arange(0, 100), idx] = 1.0
# Discrete latent code.
c1 = dis_c.view(100, -1, 1, 1)

z = torch.randn(100, 62, 1, 1, device=device)

# To see variation along c2 (Horizontally) and c1 (Vertically)
noise1 = torch.cat((z, c1, c2), dim=1)
# To see variation along c3 (Horizontally) and c1 (Vertically)
noise2 = torch.cat((z, c1, c3), dim=1)
# # To see variation along c4 (Horizontally) and c1 (Vertically)
# noise3 = torch.cat((z, c1, c4), dim=1)




# Generate image.
with torch.no_grad():
    generated_img1 = netG(noise1).detach().cpu()
# Display the generated image.
fig = plt.figure(figsize=(10, 10))
plt.axis("off")
plt.imshow(np.transpose(vutils.make_grid(generated_img1, nrow=10, padding=2, normalize=True), (1,2,0)))
plt.show()

# Generate image.
with torch.no_grad():
    generated_img2 = netG(noise2).detach().cpu()
# Display the generated image.
fig = plt.figure(figsize=(10, 10))
plt.axis("off")
plt.imshow(np.transpose(vutils.make_grid(generated_img2, nrow=10, padding=2, normalize=True), (1,2,0)))
plt.show()

 

Rotation(-1~1)

 

 

Width(-1~1)

 

결과를 보시면 오른쪽에서 왼쪽으로 갈수록 Rotaiton, Width의 특징을 학습하고 있다고 볼 수 있을 것 같습니다!

Rotation의 경우에는 조금씩 기울어지고있으며, Width의 경우에는 너비가 조금씩 커지는 것을 확인할 수 있었습니다!

논문에서와는 달리 뚜렷한 결과를 보이지 않고, 두 변수간에 섞임이 조금 있어보이는 것 같습니다. Width에서도 약간 기울어지며 생성이 되는듯한 모습을 보이고, Rotation에서도 약간 너비가 커져가는 것을 보이는 듯 합니다!

 

 

참고문헌

  1. https://github.com/Natsu6767/InfoGAN-PyTorch/tree/master

 

 

감사합니다!

 

오늘은 2016년에 NeurIPS에 발표된 InfoGAN이라는 논문을 리뷰해 보겠습니다.

1. InfoGAN이란?

InfoGAN이란 기존의 GAN 모델에서 정보이론(information-theoretic)의 개념을 추가하여 Disentangled representation을 학습할 수 있도록 하는 모델입니다.

 

Disentangled representation이란? 

데이터의 특징(feature)이나 변수가 서로 분리되어 표현된다는 것을 의미합니다.

사람의 얼굴 이미지를 다룬다고 할때 사람의 표정, 눈 색상, 헤어스타일, 선글라스 유무, 숫자 이미지를 다룬다고 할때 숫자의 크기, 두께(thickness), 각도(angle) 등의 특징이 분리되어 표현된다면 disentangled representation이라고 합니다.

분리된 feature를 통해 사람의 얼굴 이미지를 생성


기존의 GAN 모델에서는 Input data와 유사한 데이터를 만드는 것에 목적을 두었다면, InfoGAN은 유사한 데이터를 만들면서 데이터의 특징을 잘 학습하는데 중점을 둡니다. 

잘 학습되게 된다면, 숫자 데이터에서 두께, 각도 등을 다르게 생성할 수 있게 된다는 큰 장점이 있습니다.


위 그림을 보면 (a),(c),(d)의 경우 InfoGAN의 결과이고, (b)는 Original GAN의 결과입니다. 변수에 약간의 값을 변경해서 넣어주게 되면 회전하는 것과 넓이가 커지는 것을 볼 수 있습니다. (위 그림에선 c2는 회전의 특징을 학습한 변수이며 c3는 넓이의 특징을 학습한 변수라고 생각할 수 있습니다.)

 

2. Abstract

  1. ‘learn disentangled representations’
    위에서 말했듯이 정보이론의 개념을 도입해 데이터의 특징을 잘 학습할 수 있도록 학습하는 것입니다.

  2. maximizes the mutual information
    mutual information term을 maximize하는 것을 목표로 합니다. (Mutual information은 상호정보라 불리고 아래에서 자세하게 설명하겠습니다.)

  3. lower bound of the mutual information
    직접적인 mutual information term을 구할 수 없으므로 lower bound를 구해서 그 lower bound를 maximize하는 것을 목표로 합니다.

  4. learns interpretable representations
    실험결과에서 InfoGAN은 해석가능한 representation을 학습했다고 말하고 있습니다.

3. Background: Generative Adversarial Networks

GAN(Generative Adversarial Networks) 모델의 목적은 생성된 데이터 분포 $P_{G}(x)$ 가 실제 데이터 분포인 $P_{data}(x)$와 유사하게 학습하도록 하는 모델입니다. 


그러기 위해서는 생성자(Generator)와 판별기(Discriminator)를 학습하게 되는데 생성자의 경우 noise를 입력으로 받아 생성자(G)를 통해 이미지를 생성하고, 판별기(D)는 입력으로 받은 데이터가 실제 데이터인지, G가 생성해낸 이미지인지 판별하게 됩니다.


생성자(G)의 입장에서는 실제 데이터와 아주 유사하게 만들어야 하고, 판별기(D)는 실제 데이터인지, 생성자(G)가 만든 데이터인지 구분을 잘 할 수 있도록 설계되어 있다고 생각하시면 됩니다! 이 내용을 식으로 쓰게되면 아래와 같이 표현할 수 있습니다.

GAN의 목적함수

이제 G와 D입장에서 minimization, maximizatinon하는 과정을 보겠습니다.


Discriminator :
D의 입장에서는 실제 데이터를 1로 출력하고 생성된 데이터는 0이라고 출력하길 원합니다!

위 그림처럼 되면 maximization이 되겠지요!


Generator : G의 입장에서는 판별기(D)가 1로 출력해주길 원합니다!

첫 번째 항에는 G가 없으니 두 번째 항만 신경쓰면 됩니다! minimization해야하니 log0 에 수렴해야 가장 작은 값을 얻을 수 있겠죠!

 

4. Mutual Information for Inducing Latent Codes

GAN에서 noise vector(z)를 입력으로 받고 G가 데이터를 생성했었습니다. 하지만 이때 z에게 아무런 제약을 주지 않았습니다. 그렇기 때문에 Generator는 매우 꼬여있다(entangled way)고 볼 수 있습니다. 여기서 꼬여있다는 것은 z의 차원들이 data에서 의미를 가지는 feature와 대응되지 않는다고 볼 수 있습니다!

아래 사진처럼 얼굴 이미지를 생성할 때처럼 z의 차원들이 의미(특징)를 가지고 있지 않다는 것으로 보면 될 것 같습니다!

 

분리된 feature를 통해 사람의 얼굴 이미지를 생성

MNIST data의 경우만 봐도 0~9 숫자를 나타내는 digit type, 회전을 나타내는 Rotation, 너비를 나타내는 Width, 두께를 나타내는 thickness 등의 특징을 가지고 있음을 알 수 있습니다.

이런 특징에 대한 정보를 가지고서 이미지를 생성하면 원하는 데이터를 생성할 수 있지 않을까? 라는 생각에서 GAN의 확장모델인 InfoGAN이 만들어진 것입니다!

InfoGAN에서는 noise vector(z)를 두개의 파트로 나누어서 사용합니다.

  1. 압축할 수 없는 noise z
    이 z는 생성된 데이터에 무작위성이나 변동성을 주입하기 위해 사용된다고 합니다.  데이터의 구조를 설명하진 않지만, 생성된 샘플의 다양성을 증가시키는 역할을 합니다.

  2. 데이터분포에서 특징을 가지는 latent code “c”
    얼굴이미지를 생성한다고 할 때 표정, 눈 색상 등 의미를 가지는 feature를 c라고 생각하시면 됩니다.
    또한 latent code c는 factored distribution을 가정합니다.
    factored distribution은 변수들과 독립성을 가정하고 확률분포로 표현합니다.  $P(c_1,c_2,...,c_L)=\Pi_{i=1}^LP(c_i)$

저희는 이제 noise vector를 두 파트로 나누었으니 Generator에서는 z와 c를 입력으로 받아 생성해야겠죠! G(z) → G(z,c)가 됩니다!

🤔 However, in standard GAN, the generator is free to ignore the additional latent code c by finding a solution satisfying $P_G(x|c)=P_G(x)$

기존의 GAN의 모델에서는 아무런 제약을 주지 않아 $P_G(x|c)=P_G(x)$를 만족하는 solution을 찾기 때문에 c가 무시되게 됩니다.

이렇게 되는 것을 피하기 위해 우리는 추가로 제약을 걸어줘야 합니다! 그래서 저자는 information -theoretic regularization을 제안하였고, latent code c와 G(z,c)간의 상호정보량이(mutual information) 높아야 한다고 말하고 있습니다. 따라서 $I(c;G(z,c))$이 값이 높기를 원한다는 겁니다.(제가 초반에 정보이론개념을 추가할거라고 말한 부분입니다!)

정보이론에서 X와 Y 사이의 상호정보량(mutual information)은 $I(X;Y)$로 표현되고 Y 변수를 통해 X에 대해 얻어진 “정보량”이라고 하고 식은 아래와 같이 쓰일 수 있습니다.

 

 

$I(X;Y)=\Sigma_{x\in X,y\in Y}P_{X,Y}(x,y)log(\frac{P_{X,Y}(x,y)}{P_{X}(x)P_{Y}(y)})$

          $\ \ \ \ =\Sigma_{x\in X,y\in Y}P_{X,Y}(x,y)log(\frac{P_{X,Y}(x,y)}{P_{X}(x)}) - \Sigma_{x\in X,y\in Y}P_{X,Y}(x,y)(logP_Y(y))$

          $\ \ \ \ = \Sigma_{x\in X,y\in Y}P_X(x)P_{Y|X=x}(y)log(P_{Y|X=x(y)}) - \Sigma_{x\in X,y\in Y}P_{X,Y}(x,y)(logP_Y(y))$

          $\ \ \ \ = \Sigma_{x\in X}P_X(x)(\Sigma_{y\in Y}P_{Y|X=x}(y)logP_{Y|X=x}(y))-\Sigma_{y\in Y}(\Sigma_{x\in X}P_{X,Y}(x,y))logP_Y(y)$

          $\ \ \ \ =  -H(Y|X)+H(Y)  $

          $\ \ \ \ =  H(Y)-H(Y|X)   = H(X)-H(X|Y)  $

entropy term : $H(X)=-\Sigma_{x\in X}P_X(x)logP_X(x)$

 

또한 Y가 관측되었을 때 X에서 불확실성의 감소량이라고도 해석할 수 있습니다!

불확실성 감소량이 직관적으로 와닿지 않을수도 있으니 예를 들어보겠습니다.

Y(동물의 특징) : 크기, 발 개수, 이빨의 형태, 육식
X(동물) : 고양이, 얼룩말, 코끼리


동물의 특징이 주어졌을 때 동물을 예측하는것과 특징이 주어지지 않았을 때 동물을 예측하는 것 중 어느것이 더 ‘’ 예측할 수 있을까요? → 당연히 정보가 주어졌을 때 예측하기 쉽겠죠!(불확실성이 더 감소!) 이렇게 정보가 주어졌을 때 얼마나 더 잘 예측할 수 있는가를 불확실성의 감소량이라고 볼 수 있습니다!

만약 X와 Y가 독립이면

$I(X;Y)=\Sigma_{x\in X,y\in Y}P_{X,Y}(x,y)log(\frac{P_{X,Y}(x,y)}{P_{X}(x)P_{Y}(y)})$ 이 부분에서 $P_{X,Y}(x,y)=P_X(x)P_Y(y)$로 쓸 수 있기 때문에 $I(X;Y)=\Sigma_{x\in X,y\in Y}P_{X,Y}(x,y)log(1)$ =0이 되는 것을 알 수 있습니다!

직관적으로 해석해보면 동전을 던져서 나오는 결과를 생각해본다고 해보겠습니다.

첫 번째 동전을 던져서 앞면이 나왔다고 두 번째 동전이 앞면이 나올 확률이 바뀔까요? 아닙니다. 그대로 $\frac{1}{2}$인 것을 알 수 있죠. 즉, 두 번째 결과는 첫 번째 결과를 통해 얻어진 정보가 없다고 볼 수 있죠! 그렇기에 상호정보량은 0이라고 할 수 있습니다.

 

다시 돌아와서, 저희는 $P_G(x)$에서 뽑힌 x가 주어졌을 때, $P_G(c|x)$가 small entropy를 갖기를 원합니다!(small entropy는 정보의 불확실성이 적다는 것을 의미한다고 생각하시면 됩니다.)

그래서 기존 GAN의 목적함수에서 information-regularized term인 $I(c;G(z,c))$을 추가하여 이 값이 최대가 되도록 합니다! 식을 쓰면 아래와 같습니다.

 

GAN에서는 G를 minimization하도록 만들었으니 $I(c;G(z,c))$을 maximization하는 것 대신 -$I(c;G(z,c))$를 minimization하도록 term을 추가한 것입니다! ($\lambda$는 hyperparameter 입니다.)

5. Variational Mutual Information Maximization

이제 우리는 $I(c;G(z,c))$ term을 maximization[ -$I(c;G(z,c))$를 minimization하는 것과 같아서 maximization관점으로 보겠습니다.]해야합니다. 하지만, 우리는 이 식을 통해서 직접적으로 maximization할 수 없습니다. 왜그럴까요? 우선 $I(c;G(z,c))$ 식을 전개해 보겠습니다.

 

식이 되게 복잡해 보입니다… 차근차근 하나씩 살펴보죠!! 형광색으로 칠한 부분 먼저 보겠습니다.

$I(X;Y)$에서 {X ← c, Y ← G(z,c)}대입하시면 $I(c;G(z,c))$ 이렇게 위 식처럼 나오게 됩니다.

참고) $H(X)-H(X|Y)=\Sigma_{y\in Y}(\Sigma_{x\in X}P_{X|Y=y}(x)logP_{X|Y=y}(x))-\Sigma_{x\in X}(\Sigma_{y\in Y}P_{X,Y}(x,y))logP_X(x)$


그 다음 칠해져 있는 부분이 왜 저렇게 분리가 되는지 보겠습니다!

위와 같은 식($I(c;G(z,c))$)을 maximization하기 위해서는 $P(c|x)$의 분포에서 뽑은 샘플 $c'$이 필요합니다. 즉, 우리는 데이터가 주어졌을 때 그 데이터의 특징(헤어스타일, 표정 등)을 알고싶은 것입니다!
하지만 우리는 알지 못하기 때문에 auxiliary distribution(보조분포) Q를 사용하여 근사하려 합니다. 보조분포는 보통 우리가 흔히 알고있는 Gaussian distribution으로 설정합니다! 그렇게 $Q(c|x)$를 설정하게되면, $P(c|x)$의 분포와 $Q(c|x)$의 분포간 차이가 있겠죠?? 그래서 분포들 간 거리를 측정하는 지표로 KL divergence라는 함수로 거리를 측정합니다.

$$ D_{KL}(P(|x)||Q(|x))=\Sigma_{c'}p(c'|x)log\left(\frac{P(c'|x)}{Q(c'|x)}\right)=\mathbb{E}\left[log\left(\frac{P(c'|x)}{Q(c'|x)}\right) \right] $$

 

위 식을 통해 KL divergence는 다음과 같이 나눌 수 있습니다.

$D_{KL}(P(|x)||Q(|x))=\mathbb{E}_{c'\sim P(c|x)}[logP(c'|x)]-\mathbb{E}_{c'\sim P(c| x)}[logQ(c'|x)]$ 항을 넘기게 되면

$\mathbb{E}_{c'\sim P(c|x)}[logP(c'|x)]=D_{KL}(P(|x)||Q(|x))+ \mathbb{E}_{c'\sim P(c| x)}[logQ(c'|x)]$식으로 바꿀 수 있습니다!

 

$D_{KL}(P(|x)||Q(|x))$식 아래에 보면 $\geq$0 으로 표현되어 있습니다. 왜 이 식은 항상 0보다 클까요? 분포들간의 ‘거리’를 측정하는 것이기 때문에 항상 0보다 크다고 보면 됩니다! 직관적으로 해석도 되지만, 수식적으로 증명도 해보겠습니다!

$$ D_{KL}(P(|x)||Q(|x))=\mathbb{E}\left[log\left(\frac{P(c'|x)}{Q(c'|x)}\right) \right]=\mathbb{E}\left[-log\left(\frac{Q(c'|x)}{P(c'|x)}\right) \right] $$

위의 식을 그대로 갖고와서 log에음수를 취해서 분모 분자를 바꿨습니다. 편의를 위해 $\frac{Q(c'|x)}{P(c'|x)}=Z$라고 설정하겠습니다.

$\mathbb{E}\left[-log(Z)\right]$로 표현이 됩니다. 그런데 이 식을 자세히 보면 $-log(x)$와 $\mathbb{E}$의 형태로 이루어진 것을 볼 수 있습니다. $-log(x)$는 convex function이기 때문에 Jensen’s inequality를 적용할 수 있죠!


$D_{KL}(P(|x)||Q(|x))=\mathbb{E}\left[-log(Z)\right]\geq-log\mathbb{E}\left[Z\right]=-log\Sigma_{c'}P(c'|x)\frac{Q(c'|x)}{P(c'|x)}=-log(1)=0$

따라서 $D_{KL}(P(|x)||Q(|x))\geq 0$을 만족하는 것입니다!


마지막으로 $D_{KL}(P(|x)||Q(|x))$ 식은 항상 0보다 크기때문에 이 식을 삭제하면 저 부등호가 성립한다는 것도 알 수 있습니다! 결국 하고싶은건 $I(c;G(z,c))$을 maximization을 하고싶은데, $P(c|x)$를 알지 못하니까 보조분포인 $Q(c|x)$를 이용해서 lower bound를 만들고 저 lower bound를 maximization 하고자 하는 것 입니다! 이것을 variational mutual information maximization라 합니다.

하지만 여기서 또 한가지 문제점이 있죠. 저희는 $P(c|x)$를 모르기 때문에 $Q(c|x)$를 사용한다고 했는데, 마지막에 구한식을 보면 $c'$을 $P(c|x)$분포에서 뽑고있습니다.

 

저자는 이 Lemma를 사용하여 값 교묘하게 변경하였습니다!

🔑 Lemma 5.1 For random variables X,Y and function $f(x,y)$under suitable regularity conditions: $\mathbb{E}{x\sim X,y\sim Y|x}\left[f(x,y)\right]=\mathbb{E}{x\sim X,y\sim Y|x,x'\sim X|y}\left[f(x',y)\right]$.

약간의 트릭을 사용하여 이렇게 변환이 가능하다고 하네요! 이 자료는 고려대학교 임성빈 교수님께서 2가지 방식으로 증명해주신 자료입니다.

첫 번째 증명

 

2번째 줄에서 3번째줄에서 x→ x’ 으로 rename했는데, 확률변수를 바꾼게 아니라 적분할 때 변수 표기를 바꾼거에 불과하다고 합니다. 그래서 성립하는 것을 볼 수 있죠.


2번째 증명에서는 Law of total expectation(이중 기대값) 정리를 사용하여 증명하셨습니다.

Law of total expectation 증명을 먼저 보시고 아래의 사진을 참고하시면 됩니다!

🔑 $\mathbb{E}\left[\mathbb{E}\left[X|Y\right]\right]=\mathbb{E}\left[X\right]$

참고) $\mathbb{E}[Y|X]$는 X에 대한 함수입니다!

 

 

위 식이 만족하기 때문에 이렇게 되는 것을 볼 수 있습니다.

 

$I(c;G(z,c))$을 maximization하는 것 대신에 lower bound인 $L_1(G,Q)$을 maximization하자!!

따라서 infoGAN의 최종 목적식은 다음과 같습니다. ($\lambda$는 hyperparameter입니다.)


6. Experiments

실험을 통해 저자는 2가지를 달성하고자 합니다.

첫 번째 : 실제로 mutual information이 maximization이 되는지

두 번째 : InfoGAN이 구분되고 해석 가능한 representation을 학습하는지(사람의 표정, 헤어스타일 등의 특징을 잘 학습했는지)

 

6-1. Mutual Information Maximization

latent codes c와 G(z,c)간의 mutual information을 평가하기 위해서 MNIST 데이터셋을 사용했습니다. $c \sim Cat(K=10,p=0.1)$의 분포로 설정하고나서 Lower bound를 각 iteration(반복)마다 $H(c)$값을 기록했습니다. $H(10) \approx 2.30$으로 빠르게 maximization되는 것을 확인할 수 있습니다.

Cat(카테고리 분포)의 pmf를 나타내면 다음과 같습니다.

 

Lower bound L1 over training iteration

 

6-2. Disentabgled Representation

MNIST 데이터셋에서 Disentabgled Representation을 잘 학습했는지 확인하기 위해 latent code $c_1,c_2,c_3$를 추가했는데 $c_1 \sim Cat(K=10,p=0.1)$이며 $c_2,c_3$변수는 연속형 변수로 $Unif(-1,1)$을 사용하였습니다.

$c_1$ 변수의 경우 label에 대한 정보도 없이 0~9까지의 숫자를 잘 생성해낼 수 있는 것을 볼 수 있었습니다.

$c_2,c_3$ 변수의 경우 $c_2$는 숫자의 rotation(회전)에 대한 변수이고 $c_3$는 숫자의 Width(너비)에 대한 변수인것을 확인할 수 있습니다. 논문에서는 $Unif(-2,2)$를 사용해서 결과를 보여주고 있습니다.(더 극명한 결과를 보여주기 위함입니다.) latent code가 이런 특징들을 잘 포착한 걸로 보아 Disentabgled Representation을 잘 학습했다고 볼 수 있습니다.

즉, InfoGAN을 통해서 Mutual Information Maximization, Disentabgled Representation을 모두 달성했다고 볼 수 있습니다!

 

7. Conclusion

이 논문은 “Information Maximizing Generative Adversarial Networks”(InfoGAN) 이라고 불리는representation 학습 알고리즘을 소개했습니다. supervision을 요구로 하는 이전의 접근방법들과 다르게, InfoGAN은 비지도학습으로 해석과 분리가능한 representation을 학습하였다는 것입니다. 또한 GAN과 연산시간이 거의 비슷하다고 합니다!

 

 

8. 참고문헌

[1] InfoGAN : https://arxiv.org/abs/1606.03657

[2] 유재준님 블로그 : https://jaejunyoo.blogspot.com/2017/03/infogan-1.html

[3] 하우론브레인님 블로그 : https://haawron.tistory.com/10

[4] Mutual information : https://en.wikipedia.org/wiki/Mutual_information

[5] EECS498 Generative Models Part 2 : https://www.youtube.com/watch?v=igP03FXZqgo&list=PL5-TkQAfAZFbzxjBHtzdVCWE0Zbhomg7r&index=21

 

다음에는 InfoGAN을 코드로 구현해보도록 하겠습니다.

+ Recent posts