IT/AI\ML

[python/keras] 간단한 2차원 문제로 WGAN-gp의 기초를 이해하기

개발자 두더지 2022. 2. 25. 23:22
728x90

일본의 한 블로그 글을 번역한 포스팅입니다. 오역, 의역, 직역이 있을 수 있으며 내용 오류가 있다면 지적 부탁드립니다. 

WGAN-gp에 대해서 이해하기 위해, WGAN-gp를 간단한 2차원 문제를 적용해 그 결과를 살펴보고자한다. 

 

 

GAN과 WGAN-gp


1. GAN에 대해서

 GAN에서는 식별 모델이 판정한 내용을 바탕으로 생성 모델을 학습하는 것으로 생성 데이터의 분포를 실제 데이터와 가깝게 한다. 이것은 Jensen-Shannon divergence이라는 지표를 사용하여 실제 데이터와 생성 데이터의 확률 분포 차를 측정해, 그것을 최소화해나간다는 컨셉에 근거하고 있다.

 여기서 Jensen-Shannon divergence에는 기울기(구배) 소실이 일어나기 쉬워서, 모드 붕괴가 발생하는 문제가 발생하는 경우가 있다.

 

2. WGAN-gp에 대해서

 WGA-gp는 Wasserstein 거리에 의해 실제 데이터와 생성 데이터의 확률 분포의 차를 측정하여, 그것을 최소화하는 컨셉에 기초로 하고 있다. discriminator의 역할은 "실제 데이터인지 가짜 데이터인지를 식별한다" 즉 "실제 데이터와 생성 데이터의 Wasserstein 거리를 추정한다"가 된다. 

 Wasserstein 거리를 사용하기 위해서, GAN의 손실함수와 다르다. "손실함수 = Wasserstein거리 + 기울기의 제약" 이미지라고 생각하면 된다. 기울기의 제약이 추가된 이유는 Wasserstein 거리를 사용하기 위한 제약을 만족시키기 위해서이다(gp=gradient penalty).

 

 

간단한 2차원 문제에 WGAN-gp의 적용


 간단한 2차원 문제에 WGAN-gp를 적용하여 데이터가 생성되는 것을 관찰해보고자 한다. mnist의 문자 생성 등이 WGAN-gp의 도입관련되서 자주 소개되는데 데이터 분포이라는 관점에서는 관찰이 어렵기 때문에 2차원 데이터를 선택하였다.

 

목적

  • 실제 데이터의 데이터 공간에 속하는 데이터를 생성모델이 생성하고 있는 것을 확인하기
  • 구현 방법을 배우기

 

문제 설정

  • 실제 데이터와 비슷한 데이터를 생성하는 GAN 모델을 만드는 것
  • 학습 순서는 위와 동일
  • 실제 데이터는 2차원 공간 상의 아래의 이미와 같은 점

 

코드

 

GitHub - statsu1990/gan_simple_2d_problem: Applied gan to a simple two-dimensional problem to understand the features of gan.

Applied gan to a simple two-dimensional problem to understand the features of gan. - GitHub - statsu1990/gan_simple_2d_problem: Applied gan to a simple two-dimensional problem to understand the fea...

github.com

 

결과와 고찰

WGAN-gp와 GAN의 학습이 진행되는 모습은 아래의 그림과 같다.

실제 데이터
WGAN-gp의 생성 데이터의 학습 진행 모양

 

GAN의 생성 데이터의 학습 진행 모양

 

 위와 같은 결과로 다음의 것을 알 수 있었다.

  • 제일 처음은 생성 데이터가 랜덤이지만, 학습이 진행됨에 따라 실제 데이터와 비슷한 분포가 된다.
  • GAN은 그다지 수습되고 있지 않는 것으로 보이지만, WGAN-gp는 수습되고 있다(epoch가 늘어나도 분포의 변화가 적다).
  • GAN에 비해서 WGAN-gp는 모드 붕괴가 적다
  • WGAN-gp가 좋다.

 WGAN-gp에서는 손실 함수에 기울기 정보를 사용한다고 설명하였지만, 기울기에 관한 손실함수를 어느 정도 고려할 것인가에 대한 하이퍼 파라미터가 있다. "손실함수 = Wasserstein거리 + b*기울기의 제약" 식의 b이다. 

 b가 크면 클수록 다음과 같은 결과가 된다.

 활성화 함수는 생성 모델과 식별 모델 양족 모두 learkyRelu를 사용했다. GAN일 때는 식별 모델에 Relu를 사용하지 않으면 제대로 학습이 진행되지 않았는데 이는 GAN과 WGAN-gp에서의 식별 모델 역할을 차이가 있기 때문이라고 생각된다.


참고자료

https://st1990.hatenablog.com/entry/2019/07/14/190854

 

728x90