Saya sedikit bingung bagaimana penerapan KL divergence, khususnya di Keras, tapi menurut saya pertanyaannya bersifat umum untuk aplikasi deep learning. Dalam kerasnya, fungsi kerugian KL didefinisikan seperti ini:
def kullback_leibler_divergence(y_true, y_pred):
y_true = K.clip(y_true, K.epsilon(), 1)
y_pred = K.clip(y_pred, K.epsilon(), 1)
return K.sum(y_true * K.log(y_true / y_pred), axis=-1)
Dalam model saya, y_true
dan y_pred
adalah matriks; setiap baris y_true
merupakan pengkodean one-hot untuk satu contoh pelatihan, dan setiap baris y_pred
merupakan keluaran model (distribusi probabilitas) untuk contoh tersebut.
Saya dapat menjalankan penghitungan divergensi KL ini pada pasangan baris mana pun dari y_true
dan y_pred
dan mendapatkan hasil yang diharapkan. Rata-rata hasil divergensi KL pada baris-baris tersebut sesuai dengan kerugian yang dilaporkan oleh Keras dalam riwayat pelatihan. Namun agregasi tersebut - menjalankan divergensi KL pada setiap baris dan mengambil mean - tidak terjadi dalam fungsi kerugian. Sebaliknya, saya memahami MAE atau MSE untuk menggabungkan seluruh contoh:
def mean_squared_error(y_true, y_pred):
return K.mean(K.square(y_pred - y_true), axis=-1)
Untuk divergensi KL, tidak sepenuhnya jelas bagi saya bahwa mengambil mean dari seluruh contoh adalah hal yang benar untuk dilakukan. Saya kira idenya adalah bahwa contoh-contoh tersebut adalah sampel acak dari distribusi sebenarnya, sehingga contoh-contoh tersebut harus muncul sebanding dengan probabilitasnya. Namun hal tersebut tampaknya memberikan asumsi yang cukup kuat tentang bagaimana data pelatihan dikumpulkan. Saya belum benar-benar melihat aspek ini (menggabungkan seluruh sampel dari kumpulan data) yang dibahas dalam perawatan online divergensi KL; Saya hanya melihat banyak redefinisi dari rumus dasarnya.
Jadi pertanyaan saya adalah:
Apakah interpretasi tentang apa yang dilakukan Keras untuk menghasilkan kerugian divergensi KL (yaitu rata-rata pada divergensi baris KL) benar?
Mengapa ini merupakan hal yang benar untuk dilakukan?
Dari perspektif implementasi, mengapa definisi fungsi kerugian di Keras tidak melakukan agregasi pada baris seperti yang dilakukan MAE atau MSE?