Produk tensor lucu Numpy/PyTorch

Saya memiliki parameter tensor obor 4 dimensi yang didefinisikan seperti ini:

nn.parameter.Parameter(data=torch.Tensor((13,13,13,13)), requires_grad=True)

dan empat tensor dengan redup (batch_size,13) (atau satu tensor dengan redup (batch_size,4,13)). Saya ingin mendapatkan tensor dengan dims (batch_size) sama dengan rumus di akhir gambar ini : [EDIT: Saya membuat kesalahan pada gambar pertama, saya sudah memperbaikinya] masukkan deskripsi gambar di sini Saya telah melihat fungsi tensordot di dokumentasi obor, namun saya tidak dapat membuatnya berfungsi sendiri.


person Jogima_cyber    schedule 19.01.2021    source sumber
comment
Untuk memastikan, Anda mengindeks keluaran B dengan i tetapi ini berbeda dari variabel bisu tempat Anda melakukan penjumlahan terluar, bukan?   -  person trialNerror    schedule 20.01.2021
comment
Apakah result = (A[None, :, :, :, :] * X[:, :, None, None, None] * Y[:, None, :, None, None] * Z[:, None, None, :, None] * T[:, None, None, None, :]).flatten(1).sum(dim=1) yang Anda inginkan? Jika demikian saya dapat memposting jawabannya dengan penjelasan.   -  person jodag    schedule 20.01.2021
comment
Satu-satunya alasan saya bingung dan tidak memposting sebagai jawaban adalah penggunaan subskrip i di sisi kiri dan kanan persamaan yang Anda posting.   -  person jodag    schedule 20.01.2021
comment
@trialNerror Saya telah memperbaiki gambar saya, dan Anda benar, saya harus menambahkan variabel b untuk elemen batch.   -  person Jogima_cyber    schedule 20.01.2021
comment
@jodag Saya tidak yakin ini yang saya coba lakukan, karena X, Y, Z dan T dalam persamaan Anda adalah tensor redup 5 tetapi persamaan saya adalah tensor redup 2.   -  person Jogima_cyber    schedule 20.01.2021
comment
Jika A adalah tensor dims 3, maka saya berhasil melakukannya dengan torch.bmm(torch.unsqueeze(z,2),torch.bmm(torch.unsqueeze(y,1),torch.transpose(torch.matmul( x,A),0,1))).jumlah(sumbu=2).jumlah(sumbu=1)   -  person Jogima_cyber    schedule 20.01.2021
comment
@Jogima_cyber Tidak A adalah 4 dimensi dan X, Y, Z, dan T adalah 2 dimensi seperti yang Anda tunjukkan pada gambar Anda. Pengindeksan None hanya digunakan untuk memasukkan dimensi kesatuan seperti di numpy. Ini adalah siaran yang setara dengan ekspresi einsum dari jawaban Shai.   -  person jodag    schedule 20.01.2021


Jawaban (1)


kapan pun Anda memiliki produk tensor yang lucu torch.einsum (atau numpy.einsum) adalah teman Anda:

batch_size = 5
A = torch.rand(13, 13, 13, 13)
a = torch.rand(batch_size, 13)
b = torch.rand(batch_size, 13)
c = torch.rand(batch_size, 13)
d = torch.rand(batch_size, 13)
B = torch.einsum('ijkl,bi,bj,bk,bl->b', A, a, b, c, d)
person Shai    schedule 20.01.2021
comment
Hebat, terima kasih banyak, tidak mengetahui fungsi ini. Apakah bisa di GPU juga? - person Jogima_cyber; 20.01.2021
comment
@Jogima_cyber torch.einsum adalah bagian dari operasi dasar pytorch. Ini tidak hanya berfungsi pada cpu/gpu tetapi juga menyebarkan gradien ke tensor masukan. - person Shai; 20.01.2021
comment
Begitu hebatnya, perhitungan tersembunyi yang begitu rumit. - person Jogima_cyber; 20.01.2021