Моя DNN возвращает один и тот же прогноз для всех тестовых данных (tensorflow)

Я пытаюсь запустить этот код для своих данных для регрессии. Кажется, что сеть может предсказать первые тестовые данные, но все остальные предсказания такие же, как и первые. Первая функция генерирует случайные веса для инициализации. Количество предикторов равно 54, а количество выходов - 4. Вот мой код:

def init_weights(shape):
   weights = tf.random_uniform(shape, -2,2)
   return tf.Variable(weights)

def forwardprop(X, w, b, sig):
   if sig==1:
       yhat = tf.sigmoid(tf.add(tf.matmul(X, w),b))
   else:
       yhat = tf.add(tf.matmul(X, w),0.)
return yhat

def main(itr,starter_learning_rate):    

   x_size = train_X.shape[1]  
   h_size = 4
   y_size = train_y.shape[1]  

   X = tf.placeholder("float", shape = [None, x_size])
   y = tf.placeholder("float", shape = [None, y_size])

   w_1 = init_weights((x_size, h_size))
   b_1 =  tf.constant(1.)

   w_2 = init_weights((h_size, y_size))
   b_2 =  tf.constant(1.)

   yhat_1 = forwardprop(X, w_1, b_1, 1)    
   yhat =  forwardprop(yhat_1, w_2, b_2, 0)


   n_samples = train_X.shape[0]
   cost = tf.reduce_sum(tf.pow(yhat-y, 2))/(2*n_samples)
   updates = tf.train.GradientDescentOptimizer(starter_learning_rate).minimize(cost)


   sess = tf.Session()
   init = tf.global_variables_initializer()
   sess.run(init)

   for epoch in range(itr):
       sess.run(updates, feed_dict={X: train_X, y: train_y})
       train_err = train_y - sess.run(yhat, feed_dict={X: train_X, y: train_y})
       train_accuracy = np.mean(train_err**2)

       test_err  = test_y - sess.run(yhat, feed_dict={X: test_X, y: test_y})
       test_accuracy  =np.mean(test_err**2)


   print(sess.run(yhat, feed_dict={X: test_X, y: test_y}))
   sess.close()
if __name__ == '__main__':
   main(itr=10000,starter_learning_rate=0.001)

person Nima    schedule 19.02.2017    source источник


Ответы (2)


проверьте данные test_X, возможно, все записи text_X одинаковы.

person Allen Hong    schedule 20.02.2017

Нормализация вектора метки была решением.

person Nima    schedule 05.12.2017