IT/AI\ML

[python/Tensorflow2.x] Mixed Precision API

개발자 두더지 2022. 1. 24. 23:17
728x90

Mixed Precision API란?


 기존의 뉴럴 네트워크 학습시에 가중치와 활성값으로써 FP32(32bit 부동소수점)가 이용됐다. 그러나 심층 학습의 성공과 뉴럴 네트워크의 사이즈가 커지면서 최근에 뉴럴 네트워크의 학습에 걸리는 계산 비용이 급격하게 증가하게 됐다. 이러한 문제는 제품 발매를 위한 시행착오를 거치는 동안 시간이 너무 많이 걸리는 문제가 생겼다. 

 문제를 해결하기 위해 심층 학습을 위하 하드웨어를 제공하는 기업(예를 들면 NVDIA)은 계산을 고속화하기 위한 액셀러레이터를 도입하였다. 예를 들어, NVIDIA의 Volta 세대 이후의 GPU는 계산을 고속화하기 위해 Tensor Cores를 제공하고 있다. 

 그러나, FP16(16bit 부동소수점)을 가중치, 활성값, 기울기에 사용하는데 있어서, FP16의 표현력은 FP32와 비교하면 커다란 제한이 있다. 즉, FP16의 경우, 기울기의 값의 오버플로나 다운 플로가 종종 발생하며, 이로 인해 뉴럴 네트워크의 퍼포먼스에 악영향을 미치게 된다. 

Mixed Precision를 이용한 학습에서는 FP32 네트워크를 이용한 것과 같은 동일한 결과를 유지하면서 문제를 피할 수 있는 방법의 하나이다. 조금 더 자세한 설명은 Mixed Precision API에 대해서는 공식 튜토리얼에도 잘 설명되어 있으며, 이 기능을 베이스로 한 Mixed Presision Training이라는 논문이 존재하므로 참고가 되리라고 생각한다. 

 

 

Mixed Precision API 사용법


Mixed Precision API를 사용하기 위해서는 weight를 정의할 때 dtype을 명시적으로 지정하지 않는 것이 아닌 프로그램의 선두 부분에 정책을 지정한다. 예를 들어 아래의 코드에서는 "mixed_float16"이라는 정책을 지정했다.

policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)

 이 "mixed_float16"이라는 정책에서는 변수가 float32으로 유지시켜두지만, 계산 자체는 float16으로 진행된다. float32는 기울기의 업데이트시에 사용되어, 수치 계산의 안전성을 담보한다. 또한, 행렬 연산에서는 요소마다 곱을 한 결과에 대한 합은 float32로 실행된다. 

 현재 지정된 정책이 어떻게 float32와 float16으로 나눠 사용되고 있는가는 아래의 코드로 확인 가능하다. 

print('Compute dtype: %s' % policy.compute_dtype)
print('Variable dtype: %s' % policy.variable_dtype)
Compute dtype: float16
Variable dtype: float32

 참고로, 손실함수에 전달된 값이 float16이라면, 기울기 계산시에는 오버 플로/다운 플로 문제가 발생할 수 있으므로, 아래와 같이 마지막 층은 float32로 해두는 것이 정석인 것 같다.

x = layers.Dense(10, name='dense_logits')(x)
outputs = layers.Activation('softmax', dtype='float32', name='predictions')(x)

 한편, NVIDIA의 경우 TensorRT6.0이 지원돼 기본적으로 사용되고 있는듯하다.


참고자료

https://nnabla.readthedocs.io/ja/latest/python/tutorial/mixed_precision_training.html

https://qiita.com/ohtaman/items/ed46e8c2544f6b13bccd

https://www.tensorflow.org/guide/mixed_precision?hl=ko

728x90