Ddanggle in Ml minutes

(Pytorch를 사용한) 단 50줄로 코드로 짜보는 GAN

이 글은 저자 Dev Nag의 허락을 받아 (Pytorch를 사용해서) 단 50줄로 코드로 짜보는 GAN의 듀토리얼 글을 번역한 것입니다. 원문도 꼭 읽어보셨으면 합니다.

GAN은 생각보단 간단합니다.

GAN은 지난 10년동안 머신러닝 업계에 나온 가장 뛰어난 아이디어!-얀 르쿤

2014년 몬트리올 대학의 이안 굿 펠로와 그의 동료들은 GAN, Generative Adversarial Networks을 소개하는 걸출한 논문을 발표했습니다. 컴퓨터 모델과 게임이론이 합쳐진 획기적인 조합으로, 충분한 모델링 파워가 주어진다면, 두 모델은 서로 치고박으면서 오래되고 일반적인 역전파방법이 둘-모두 트레이닝 시키는 방법입니다.

두 모델은 완전히 구별되는 (말 그대로, 적대적인 ) 관계입니다. 주어진 실제 데이터셋 **R**, **G** 등은 단순히 진짜처럼 보이도록 하는 가짜 데이터를 생성하는 __생성자__ 이고, **D**는 실제 데이터 셋 혹은 **G**에서 데이터를 가지고 와서 다른점을 분류하는(라벨링하는) __식별자__ 입니다. 굿 펠로는 **G**는 진품을 자신의 결과물로 위조해서 보이게하는 위조집단이고 **D**는 위조품과 진품의 다른 점을 찾는 형사집단'이라 할 수 있다고 (꽤 적절하게) 비유했습니다. (그렇지만 예외적으로 이번 경우에는 위조집단 **G**는 절대 실제 데이터를 볼 수 없고 **D**의 판단만 볼 수 있습니다. **장님** 위조집단인 것이죠)

이미지

이상적인 경우에는 DG 둘 모두 G 가 실제 작품의 완전한 “위조전문가”가 되고, D가 패할 때까지 계속 성장합니다. D가 “더 이상 다른 점을 찾을 수가 없습니다.” 라고 선언할 때까지요.

실행하는 동안, 굿 펠로가 설명했듯이 G는 실제 데이터에서, (가능한) 더 낮은 차원의 방식안에서 데이터를 대표하는 어떤 방법을 찾는 비지도학습 방식으로 수행합니다. 얀 르쿤의 명언을 빌리자면, 비지도 학습은 실제 AI의 “케잌”과 같은 존재 입니다.

=====

이런 걸출한 방법을 단순히 시작하는 것만도 엄청난 양의 코드가 필요해보이죠? 사실 별로 많진 않습니다. PyTorch를 사용하면, 50줄도 안되는 코드로 간단한 GAN을 만들어볼 수 있습니다. 만들기 위해서는 5개의 컴포넌트만 생각하면 됩니다.

  • R: 실제 진짜 데이터셋
  • I: 생성자에 불확실성(entorpy)의 자원으로 들어가서 사용되는 랜덤노이즈
  • G: 실제 데이터 셋을 복제/흉내 낼 생성자
  • D: G의 결과물이 실제 데이터셋 R과 얼마나 다른 지 말해주는 식별자
  • G 가 D를 흉내 내도록 가리키고, DG에게 주의를 주도록 상기시켜주는 실행되는 실제 ‘트레이닝’ 루프

1) R: 이번 예시에서는, 가장 간닪란 벨 커브를 R로 사용하겠습니다. 벨 커브는 평균과 표준편차를 가지고 가우시안의 파라미터들로 샘플 데이터들의 곡선을 반환하는 함수입니다. 아래 예시 코드에서는 평균으로 4.0을, 표준편차로 1.25를 사용하겠습니다.

def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tenser(np.random.normal(mu, sigma, (1, n))) # 가우시안

2) I : 생성자에 넣는 입력값 또한 랜덤이지만, 이 일을 조금 더 어렵게 만들기 위해 평범한 것 대신에 연속균등분포(uniform distribution)를 사용하겠습니다. 그러면 모델 GR을 간단하게 옮기거나 범위들을 그 옮기기 힘듭니다. 그렇지만 비선형적으로 데이터의 분포를 바꿔버립니다.

def get_generator_intput_sampler():
    return lambda m, n: torch.rand(m, n)

3)G: 생성자는 두 층의 히든 레이어와 세 개의 선형사상(linear map)을 가지고 있는 표준 전방전달 (feedforward) 그래프입니다. 우리는 ELU (자연로그 선형 유닛)를 사용할 것입니다. 왜냐면 제일 최신의 힙한 것이거든요! GI의 균등하게 분배된 데이터 샘플을 계속 제공받고 어떻게든 R의 가우스분포(정규분포) 샘플을 흉내내고 있습니다.

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.sigmoid(self.map2(x))
        return self.map3(x)

4) D: 이 식별자 코드는 G의 제너레이터(생성) 코드와 비슷합니다. 2층의 히든레이어와 3개의 선형사항으로 구성된 표준 전방전달 그래프입니다. R이나 G에서 가져온 예시로 ‘가짜’인지 ‘진짜’인지 판단하고 0과 1사이의 스칼라 값을 결과로 내놓습니다. 하지만, 뉴럴 네트워크가 내놓는 만들다 만 것 같은 값입니다.

(역주 :This is about as milquetoast as a neural net can get. 마지막 문장이 어색해서 추가 해놓습니다. 대충 의미는 감이 잡히는 데 저도 제대로 이해 못한 것 같네요. milquetoast 는 겁쟁이, 비겁한 놈 같은 의미인데, 여기에서 어떻게 적용될 지 모르겠습니다. 코멘트 주시면 바로 반영할게요! )

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self, map1(x))
        x = F.elu(self, map2(x))
        return F.sigmoid(self.map3(x)) 

5) 마지막으로, 아래의 두 방법을 번갈아서 반복적으로 훈련하게됩니다. 먼저, 정확한 분류(라벨들) ( 경찰학교 라고 생각하시면 됩니다)로 진짜 데이터와 가짜 데이터를 싸움 붙여서 D를 훈련시키고, 이 바보 D로 정확하지 않은 라벨들로 G를 훈련시킵니다. (오션스 일레븐 에서처럼 준비차원에서 가짜로 해보는 겁니다). 말 그대로 선과 악이 싸우는 거죠.

imgofcode

# 초록색부분
for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. D를 진짜와 가짜 데이터로 훈련시킨다. 
        D.zero_grad()

        #  1A: 실제 데이터로 D를 훈련시킨다. 
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = true
        d_real_error.backward() # 마분값을 계산 / 저장하지만, 파라미터를 바꾸지는 않는다. 

        #  1B: 가짜 데이터로 D를 훈련시킨다. 
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # G가 훈련할 때 이 라벨들을 피하도록 빼버린다. 
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))  # zeros = fake
        d_fake_error.backward()
        d_optimizer.step()     # 오직 D의 파라미터만 최적화한다. backward()로 게산된 저장된 미분값에 의해 바뀐다.
# 빨간색부분
for g_index in range(g_steps):
        # 2. D의 응답을 활용해서 G를 트레이닝 시킨다. (그렇지만 D를 이 라벨로 트레이닝 하지는 않는다. )
        G.zero_grad()

        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))  # 우리는 바보를 원하기 떄문에 모든 게 진짜 것 처럼 행동한다.

        g_error.backward()
        g_optimizer.step()  # 오직 G의 파라미터만 최적화한다. 

Pytorch를 한 번도 보지못했더라도, 어떤 일이 일어나는 지 이해할 수 있을 것입니다. 첫 번째(녹색) 부분을 보면, DD가 추측한 것과 실제 라벨들 둘 다 구별가능한 기준을 적용시켜 전방(forward) 단계에서 통과시킵니다. 그리고 d_optimizer.step()에서 D의 파라미터를 업데이트하기 위한 경사도 계산을 위해 backward()를 호출합니다. 여기서 G는 사용은 되지만 트레이닝되지는 않았습니다.

그리고 두 번째(빨간색) 부분을 보면, G 에도 동일한 방법을 적용시킵니다. D에 통과시키면서 결과값 G 를 내놓다는 것도 명심하세요. (위조범에게 형사가 직접 연습시키는 것입니다) 하지만, 이 단계에서 D를 최적화하거나 변화시키지는 않을 것입니다. 형사 D가 잘못 분류된(라벨) 것으로 훈련되면 안 되니까요. 그래서 g_optimizer.step()만 호출합니다.

그리고.. 이게 전부입니다! 다른 몇 가지 기본 코드들이 좀 더 있지만, 구체적인 GAN 특징은 이 5가지 컴포넌트가 끝입니다. 더 없습니다.

DG가 이 부딪치는 금지된 춤을 게속 추고 나면 무엇을 얻을 수 있을까요? (G가 천천히 좋아질 동안) 구별자 D는 매우 빠른 속도로 좋아집니다. 그렇지만 어느 정도 힘을 얻고 나서 G가 상대가 되는 적수가 되면 향상되기 시작합니다. 정말 미친듯이 향상됩니다.

2만번의 트레이닝이 지나고나면, G의 결과값의 평균이 4.0을 넘어가지만, 적절한 단계로 진입해서 좀 안정을 찾게 됩니다. (왼쪽사진) 이와 비슷하게도 표준편차도 처음에는 완전히 잘못된 방향에 있다가도 R값과 매칭되는 값인 기대했던 1.25 범위까지 도달했습니다.

img1 img2

기본적인 통계는 R과 매칭됩니다. 그러면 더 높은 순간은 어떨까요? 분포형태가 그대로 유지될까요? 다 끝나고 나면, 평균이 4이고, 표준편차가 1.25인 균등분포를 이루게 됩니다. 그러나 이건 R과 완전 동일하게 매칭되지는 않습니다. 그러면 마지막으로 G를 활용해서 최종 분포를 만들어보겠습니다.

img3

나쁘지 않죠? 왼쪽꼬리가 오른쪽보다 조금 더 길지만, 왜곡도나 뾰족한 정도를 보면 원래의 가우시안을 떠올리게 합니다.

G는 정말 비슷하게 진짜 분포인 R가 되었습니다. - 그리고 D는 왼쪽 끝에 있어서 이 이야기에 자기 주장이 전혀 반영되지 않습니다. 50줄도 안되는 코드로 우리가 원하는 것을 정확하게 만들어냈습니다.(굿펠로우의 첫번째 그림을 보세요)

굿펠로우는 GAN에 관련된 2016 gem describing some practical improvements나 미니배치 방법을 적절히 활용하는 등의 다른 많은 논문들을 썼습니다. 여기 NIPS2016에서의 두 시간짜리 듀토리얼도 있습니다. 텐서플로우 유저들은 Aylien의 글을 참조해보세요.

자, 이정도면 되었습니다. 코드를 한 번 보시죠


댓글