Memformat data dengan benar untuk jaringan saraf berulang lstm di R/mxnet

Saya ingin melatih jaringan saraf lstm menggunakan fungsi mx.lstm di paket R mxnet. Data saya terdiri dari n vektor fitur, vektor kelas berlabel, dan vektor waktu, seperti contoh tiruan ini yang mana X1, X2, X3 adalah fiturnya:

dat <- data.frame(
  X1 = rnorm(100, 1, sd = 1),
  X2 = rnorm(100, 2, sd = 1),
  X3 = rnorm(100, 3, sd = 1),
  class = sample(c(1,0), replace = T, 100),
  time =  seq(0.01,1,0.01))

Bantuan untuk mx.lstm menyatakan bahwa argumen train.data memerlukan "mx.io.DataIter atau list(data=R.array, label=R.array) The Training set".

Saya sudah mencoba ini:

library(mxnet)

# Convert dummy data into suitable format
trainDat <- list(data = array(c(dat$X1, dat$X2, dat$X3), dim = c(100,3)), 
label = array(dat[,4], dim = c(100,1)))

# Set the basic network parameters for the lstm (arbitrary for this example)
batch.size = 32
seq.len = 32
num.hidden = 16
num.embed = 16
num.lstm.layer = 1
num.round = 1
learning.rate = 0.1
wd = 0.00001
clip_gradient = 1
update.period = 1

# Run the model
model <- mx.lstm(train.data = trainDat,
             ctx=mx.cpu(),
             num.round=num.round, 
             update.period=update.period,
             num.lstm.layer=num.lstm.layer, 
             seq.len=seq.len,
             num.hidden=num.hidden, 
             num.embed=num.embed, 
             num.label=vocab,
             batch.size=batch.size, 
             input.size=vocab,
             initializer=mx.init.uniform(0.1), 
             learning.rate=learning.rate,
             wd=wd,
             clip_gradient=clip_gradient)

Yang mengembalikan "Kesalahan dalam mx.io.internal.arrayiter(as.array(data), as.array(label), unif.rnds, : basic_string::_M_replace_aux"

Ada contoh lstm di situs mxnet, tetapi data yang digunakan sangat berbeda dengan milik saya dan saya tidak dapat memahaminya.

http://mxnet.io/tutorials/r/charRnnModel.html

Jadi, pertanyaan saya adalah bagaimana cara mengubah data saya ke format yang sesuai untuk mx.lstm?


person RCW    schedule 23.12.2016    source sumber


Jawaban (1)


Saya mencoba mereproduksi kesalahan Anda dan mendapat pesan yang lebih detail:

Kesalahan di mx.io.internal.arrayiter(as.array(data), as.array(label), unif.rnds, : io.cc:50: Tampaknya X, y diteruskan dengan cara utama Row, MXNetR mengadopsi a konvensi utama kolom. Harap berikan transpos X sebagai gantinya

Saya memperbaiki kesalahan dengan meneruskan data dan memberi label array ke aperm().

trainDat <- list(data = aperm(array(c(dat$X1, dat$X2, dat$X3), dim = c(100,3))), label = aperm(array(dat[,4], dim = c(100,1))))
person lynguyen    schedule 27.12.2016