Setel ulang parameter jaringan saraf di pytorch

Saya memiliki jaringan saraf dengan struktur berikut:

class myNetwork(nn.Module):
    def __init__(self):
        super(myNetwork, self).__init__()
        self.bigru = nn.GRU(input_size=2, hidden_size=100, batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(200, 32)
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        self.fc2 = nn.Linear(32, 2)
        torch.nn.init.xavier_uniform_(self.fc2.weight)

Saya perlu mengembalikan model ke keadaan yang belum dipelajari dengan mengatur ulang parameter jaringan saraf. Saya dapat melakukannya untuk nn.Linear lapisan dengan menggunakan metode di bawah ini:

def reset_weights(self):
    torch.nn.init.xavier_uniform_(self.fc1.weight)
    torch.nn.init.xavier_uniform_(self.fc2.weight)

Namun, untuk mengatur ulang bobot lapisan nn.GRU, saya tidak dapat menemukan cuplikan seperti itu.

Pertanyaan saya adalah bagaimana cara mereset lapisan nn.GRU? Cara lain untuk mengatur ulang jaringan juga baik-baik saja. Bantuan apa pun dihargai.


person learner    schedule 28.08.2020    source sumber


Jawaban (2)


Anda dapat menggunakan metode reset_parameters pada layer. Seperti yang diberikan di sini

for layer in model.children():
   if hasattr(layer, 'reset_parameters'):
       layer.reset_parameters()

Atau Cara lain adalah dengan menyimpan model terlebih dahulu dan kemudian memuat ulang status modul. Menggunakan torch.save dan torch.load lihat dokumen untuk informasi lebih lanjut Atau Model Penyimpanan dan Pemuatan

person Dishin H Goyani    schedule 28.08.2020

Baru mengenal pytorch, saya ingin tahu apakah ini bisa menjadi solusi :)

Misalkan Model melekat dari torch.nn.module,

untuk mengatur ulang ke nol:

dic = Model.state_dict()
for k in dic:
    dic[k] *= 0
Model.load_state_dict(dic)
del(dic)

untuk mengatur ulang secara acak

dic = Model.state_dict()
for k in dic:
    dic[k] = torch.randn(dic[k].size())  
Model.load_state_dict(dic)
del(dic)
person Jiayi Pan    schedule 19.07.2021