รีเซ็ตพารามิเตอร์ของโครงข่ายประสาทเทียมใน pytorch

ฉันมีโครงข่ายประสาทเทียมที่มีโครงสร้างดังต่อไปนี้:

class myNetwork(nn.Module):
    def __init__(self):
        super(myNetwork, self).__init__()
        self.bigru = nn.GRU(input_size=2, hidden_size=100, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(200, 32)
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        self.fc2 = nn.Linear(32, 2)
        torch.nn.init.xavier_uniform_(self.fc2.weight)

ฉันต้องคืนสถานะโมเดลให้เป็นสถานะที่ไม่ได้รับการเรียนรู้โดยการรีเซ็ตพารามิเตอร์ของโครงข่ายประสาทเทียม ฉันสามารถทำได้สำหรับ nn.Linear เลเยอร์โดยใช้วิธีการด้านล่าง:

def reset_weights(self):
    torch.nn.init.xavier_uniform_(self.fc1.weight)
    torch.nn.init.xavier_uniform_(self.fc2.weight)

แต่หากต้องการรีเซ็ตน้ำหนักของเลเยอร์ nn.GRU ฉันไม่พบตัวอย่างข้อมูลดังกล่าว

คำถามของฉันคือเราจะรีเซ็ตเลเยอร์ nn.GRU ได้อย่างไร วิธีอื่นในการรีเซ็ตเครือข่ายก็ใช้ได้เช่นกัน ความช่วยเหลือใด ๆ ที่ชื่นชม


person learner    schedule 28.08.2020    source แหล่งที่มา


คำตอบ (2)


คุณสามารถใช้วิธี reset_parameters บนเลเยอร์ได้ ตามที่ให้ไว้ ที่นี่

for layer in model.children():
   if hasattr(layer, 'reset_parameters'):
       layer.reset_parameters()

หรืออีกวิธีหนึ่งคือการบันทึกโมเดลก่อนแล้วจึงโหลดสถานะโมดูลอีกครั้ง การใช้ torch.save และ torch.load ดูเอกสารเพิ่มเติม หรือ การบันทึกและการโหลดโมเดล

person Dishin H Goyani    schedule 28.08.2020

ใหม่สำหรับ pytorch ฉันสงสัยว่านี่อาจเป็นวิธีแก้ปัญหา :)

สมมติว่าโมเดลสืบทอดมาจาก torch.nn.module

หากต้องการรีเซ็ตเป็นศูนย์:

dic = Model.state_dict()
for k in dic:
    dic[k] *= 0
Model.load_state_dict(dic)
del(dic)

เพื่อรีเซ็ตแบบสุ่ม

dic = Model.state_dict()
for k in dic:
    dic[k] = torch.randn(dic[k].size())  
Model.load_state_dict(dic)
del(dic)
person Jiayi Pan    schedule 19.07.2021