IT/AI\ML

[python/Tensorflow2.X] "ValueError: tf.function-decorated function tried to create variables on non-first call" 에러 해결하기

개발자 두더지 2022. 8. 16. 19:37
728x90

일본의 한 블로그 글을 번역한 포스트입니다. 오역 및 의역, 직역이 있을 수 있으며 틀린 내용이 있으면 지적해주시면 감사하겠습니다.

 

 

에러가 발생하는 예제 코드


@tf.function
def call(model1: tf.keras.models.Model, inputs: tf.Tensor):
    return model1(inputs)


if __name__ == '__main__':
    model1 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)

    ])
    model2 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)
    ])
    inputs = tf.ones((10, 10), dtype=tf.float32)
    call(model1, inputs)  # raises no error!
    call(model2, inputs)  # raises an error! "tf.function-decorated function tried to create variables on non-first call"

 tf.keras.models.Model을 입력하려고한 tf.function에서 "tried to create variables on non-first call"가 발생하는 예제 코드이다. 

 포인트는  keras의 모델은 제일 처음에 호출됐을 때 가중치 행렬 등의 Variable을 build한다는 것이다. 따라서 이 call 함수에 빌드되어 있지 않은 두 가지 모델을 인수로써 전달하려고 하면 두 번째 호출시 두 번째 모델의 무게를 빌드하려다가 에러가 발생하게 된다는 것이다. 

 

 

해결책


# @tf.function를 직접 붙이지 않는다.
def call(model1: tf.keras.models.Model, inputs: tf.Tensor):
    return model1(inputs)


if __name__ == '__main__':
    model1 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)

    ])
    model2 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)
    ])
    inputs = tf.ones((10, 10), dtype=tf.float32)
    # python에서는 함수도 오브젝트이다. 각가의 모델 전용 함수를 생성하도록 한다.
    model1_call = tf.function(call)
    model2_call = tf.function(call)
    model1_call(model1, inputs)
    model2_call(model2, inputs)

tf.function을 함수의 정의 부분에 직접 쓰지 않고 각 모델 전용의 함수 오브젝트를 생성하도록 한다. 


참고자료

https://qiita.com/Yosemat1/items/6aeca92cb65b052cbafd

728x90