Смешной тензорный продукт Numpy / PyTorch

У меня есть 4-х мерный параметр тензора факела, определенный следующим образом:

nn.parameter.Parameter(data=torch.Tensor((13,13,13,13)), requires_grad=True)

и четыре тензора с dims (batch_size, 13) (или один тензор с dims (batch_size, 4,13)). Я хотел бы получить тензор с dims (batch_size), равный формуле в конце этого изображения: [EDIT: я сделал ошибку на первом рисунке, я исправил его]  введите описание изображения здесь Я видел в документации к факелу функцию tensordot, но мне не удается заставить ее работать самостоятельно.


person Jogima_cyber    schedule 19.01.2021    source источник
comment
Чтобы быть уверенным, вы индексируете свой вывод B с помощью i, но он отличается от переменной mute, для которой вы выполняете самую внешнюю сумму, верно?   -  person trialNerror    schedule 20.01.2021
comment
result = (A[None, :, :, :, :] * X[:, :, None, None, None] * Y[:, None, :, None, None] * Z[:, None, None, :, None] * T[:, None, None, None, :]).flatten(1).sum(dim=1) то, что вы хотите? Если да, я могу опубликовать ответ с объяснением.   -  person jodag    schedule 20.01.2021
comment
Единственная причина, по которой я смущен и не публикую в качестве ответа, - это использование индекса i в левой и правой частях уравнения, которое вы опубликовали.   -  person jodag    schedule 20.01.2021
comment
@trialNerror Я исправил свою картинку, и вы правы, мне пришлось добавить переменную b для пакетных элементов.   -  person Jogima_cyber    schedule 20.01.2021
comment
@jodag Я не уверен, что это то, что я пытаюсь сделать, поскольку X, Y, Z и T в вашем уравнении - это тензоры dims 5, а mines - тензоры dims 2.   -  person Jogima_cyber    schedule 20.01.2021
comment
Если A - тензор dims 3, то мне удается это сделать с torch.bmm (torch.unsqueeze (z, 2), torch.bmm (torch.unsqueeze (y, 1), torch.transpose (torch.matmul ( x, A), 0,1))). sum (ось = 2) .sum (ось = 1)   -  person Jogima_cyber    schedule 20.01.2021
comment
@Jogima_cyber Нет, A четырехмерный, а X, Y, Z и T двухмерный, как вы указали на своем изображении. Индексирование None просто используется для вставки единичных размеров, как в numpy. Это транслируемый эквивалент выражения einsum ответа Шая.   -  person jodag    schedule 20.01.2021


Ответы (1)


всякий раз, когда у вас есть забавный тензорный продукт torch.einsum (или numpy.einsum) ваш друг:

batch_size = 5
A = torch.rand(13, 13, 13, 13)
a = torch.rand(batch_size, 13)
b = torch.rand(batch_size, 13)
c = torch.rand(batch_size, 13)
d = torch.rand(batch_size, 13)
B = torch.einsum('ijkl,bi,bj,bk,bl->b', A, a, b, c, d)
person Shai    schedule 20.01.2021
comment
Отлично, спасибо, не знал об этой функции. На gpu тоже может работать? - person Jogima_cyber; 20.01.2021
comment
@Jogima_cyber torch.einsum является частью основных операций pytorch. Он не только работает с процессором / графическим процессором, но также распространяет градиенты на входные тензоры. - person Shai; 20.01.2021
comment
Такие великолепные, такие сложные скрытые вычисления. - person Jogima_cyber; 20.01.2021