개발공부

텐서플로우의 모델, 네트워크, 웨이트를 저장하고 불러오는 방법 본문

Python/Deep Learning

텐서플로우의 모델, 네트워크, 웨이트를 저장하고 불러오는 방법

mscha 2022. 6. 14. 11:45

모델명 = model

import tensorflow as tf

 

학습시킨 모델을 저장하는 방법

1. 폴더로 저장

model.save('저장할 이름')

 

2. 파일로 저장

model.save('저장할 이름.h5')

 

저장한 모델을 불러오는 방법

tf.keras.models.load_model('저장한 이름')

 

네트워크만 저장하기

my_network = model.to_json()

with open('저장할 이름.json', 'w') as json_file :
    json_file.write(my_network)

 

저장한 네트워크 불러오기

with open('저장한 이름.json', 'r') as json_file :
    my_net2 = json_file.read()

아래 model2 는 네트워크만 가져온 것이지, 학습완료된 웨이트를 가져온 것이 아니라,
현재 웨이트는 학습이 안된, 랜덤으로 셋팅된 웨이트다.
따라서 이것으로 예측을 수행하면 안된다.

# 모델 생성
model2 = tf.keras.models.model_from_json(my_net2)

 

웨이트만 저장하기

model.save_weights('저장할 이름.h5')

 

저장한 웨이트를 학습이 안된 모델에 불러오기

model2.load_weights('저장한 이름.h5')