DNN saya mengembalikan prediksi yang sama untuk semua data pengujian (tensorflow)

Saya mencoba menjalankan kode ini pada data saya untuk regresi. Tampaknya jaringan dapat memprediksi data pengujian pertama tetapi semua prediksi lainnya sama dengan yang pertama. Fungsi pertama menghasilkan bobot acak untuk inisialisasi. Jumlah prediktornya 54 dan jumlah keluarannya 4. Ini kode saya:

def init_weights(shape):
   weights = tf.random_uniform(shape, -2,2)
   return tf.Variable(weights)

def forwardprop(X, w, b, sig):
   if sig==1:
       yhat = tf.sigmoid(tf.add(tf.matmul(X, w),b))
   else:
       yhat = tf.add(tf.matmul(X, w),0.)
return yhat

def main(itr,starter_learning_rate):    

   x_size = train_X.shape[1]  
   h_size = 4
   y_size = train_y.shape[1]  

   X = tf.placeholder("float", shape = [None, x_size])
   y = tf.placeholder("float", shape = [None, y_size])

   w_1 = init_weights((x_size, h_size))
   b_1 =  tf.constant(1.)

   w_2 = init_weights((h_size, y_size))
   b_2 =  tf.constant(1.)

   yhat_1 = forwardprop(X, w_1, b_1, 1)    
   yhat =  forwardprop(yhat_1, w_2, b_2, 0)


   n_samples = train_X.shape[0]
   cost = tf.reduce_sum(tf.pow(yhat-y, 2))/(2*n_samples)
   updates = tf.train.GradientDescentOptimizer(starter_learning_rate).minimize(cost)


   sess = tf.Session()
   init = tf.global_variables_initializer()
   sess.run(init)

   for epoch in range(itr):
       sess.run(updates, feed_dict={X: train_X, y: train_y})
       train_err = train_y - sess.run(yhat, feed_dict={X: train_X, y: train_y})
       train_accuracy = np.mean(train_err**2)

       test_err  = test_y - sess.run(yhat, feed_dict={X: test_X, y: test_y})
       test_accuracy  =np.mean(test_err**2)


   print(sess.run(yhat, feed_dict={X: test_X, y: test_y}))
   sess.close()
if __name__ == '__main__':
   main(itr=10000,starter_learning_rate=0.001)

person Nima    schedule 19.02.2017    source sumber


Jawaban (2)


periksa data test_X, mungkin semua record text_X sama.

person Allen Hong    schedule 20.02.2017

Normalisasi vektor label adalah solusinya.

person Nima    schedule 05.12.2017