Ошибка входного размера в последовательной плотной сети с Keras

Это длинный вопрос, потому что я пытаюсь объяснить свою проблему, потому что это повторяющаяся проблема для меня, и я действительно не понимаю, поэтому спасибо, что нашли время, чтобы прочитать меня

Я хочу создать последовательную плотную модель, которая принимает в качестве входного списка такой размер:

[размер_пакета, размер_данных]

Итак, я определил свою сеть так:

ModelDense = Sequential()

ModelDense.add(Dense(380, input_shape=(None,185), activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(380, activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(380, activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(7, activation='elu', kernel_initializer='glorot_normal'))
optimizer = tf.keras.optimizers.Adam(lr=0.00025)

ModelDense.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])

но когда я использую эту сеть с вводом такой формы: (1, 185), я получил ошибку:

Ошибка при проверке ввода: ожидалось, что плотный_ввод будет иметь 3 измерения, но получил массив с формой (185, 1)

Не спрашивайте меня, почему я сказал, что моя векторная фигура (1, 185), а в сообщении об ошибке мы видим (185, 1), потому что, когда я проверяю свою форму массива непосредственно перед тем, как передать ее в качестве входных данных в мою сеть, показана форма (1, 185)

Хорошо, поэтому я проверил несколько тем и нашел эту, в котором поясняется, что:

Для плотных слоев требуются входные данные как (batch_size, input_size) или (batch_size, optional, ..., optional, input_size)

Так вот что я сделал, не так ли? Но я также видел это:

Формы в Керасе:

...

Итак, даже если вы использовали input_shape = (50,50,3), когда keras отправляет вам сообщения или когда вы распечатываете сводку модели, она будет отображать (None, 50,50,3)

...

Итак, при определении формы ввода вы игнорируете размер пакета: input_shape = (50,50,3)

Ok ! давайте попробуем, теперь я определил свой входной слой следующим образом:

ModelDense.add(Dense(380, input_shape=(185,), activation='elu', kernel_initializer='glorot_normal'))

Когда я делаю model.summary ():

_________________________________________________________________ Layer (type) Output Shape Param # =========================================== ======================== плотный (плотный) (нет, 380) 70680 _________________________________________________________________ плотный_1 (плотный) (нет, 380) 144780 _________________________________________________________________ плотный_2 (плотный) (Нет, 380) 144780 _________________________________________________________________ плотный_3 (Плотный) (Нет, 7) 2667 =================================== ================================ Всего параметров: 362 907 обучаемых параметров: 362 907 Не обучаемых параметров: 0


Хорошо, я думаю, что это то, что я хочу, но когда я даю массив ТО ЖЕ в качестве ввода, я получаю сообщение об ошибке:

ValueError: Ошибка при проверке ввода: ожидалось, что density_input будет иметь форму (185,), но получил массив с формой (1,)

Я в замешательстве, что я не понимаю?

_________EDIT__________:

Функция прогноза:

def predict(dense_model, state, action_size, epsilon):

    alea = np.random.rand()

    # DEBUG
    print(state)
    print(np.array(state).shape)

    output = dense_model.predict(state)

    if (epsilon > alea):
        action = random.randint(1, action_size) - 1
        flag_alea = True

    else:
        action = np.argmax(output)
        flag_alea = False

    return output, action, flag_alea

Строка, в которой я использую свою функцию:

Qs, action, flag_alea = predict(Dense_model, [state], ACTION_SIZE, Epsilon)

Точный результат моей печати DEBUG:

[[0.0, 0.0, 0.0, 0.12410027302060064, 0.0, 0.0, 0.0, 0.0, 0.0, 0.18851780241253108, 0.0, 0.0, 0.2863141820958198, 0.0, 0.07328154770628756, 0.418848167539267, 0.07328154770628756, 0.2094240837696335, 0.42857142857142855, 0.0, 0.12410027302060064, 0.0, 0.0, 0.0, 0.0, 0.263306220774655, 0.14740566037735847, 0.40346984062941293, 0.675310642895732, 0.0, 0.0, 0.0, 0.0, 0.07328154770628756, 0.0, 0.4396892862377253, 0.0, 0.42857142857142855, 0.0, 0.12410027302060064, 0.08759635599159075, 0.0, 0.1401927621025243, 0.6755559204272007, 0.0, 0.0, 0.11564568886156315, 0.4051863857374392, 0.0, 0.0, 0.19087612139721322, 0.0, 0.07328154770628756, 0.6282722513089005, 0.14656309541257512, 0.10471204188481675, 0.42857142857142855, 0.0, 0.12410027302060064, 0.0, 0.0, 0.0, 0.0, 0.0974621385076755, 0.0, 0.0, 0.675310642895732, 0.0, 0.0, 0.0, 0.09543806069860661, 0.07328154770628756, 0.10471204188481675, 0.5129708339440129, 0.5233396901920598, 0.42857142857142855, 0.0, 0.0, 0.0, 0.0, 0.5528187746700128, 0.6755564266434103, 0.0, 0.0, 0.10086746015735323, 0.1350621285791464, 0.0, 0.0, 0.0, 0.0, 0.14891426591693724, 0.5166404112353377, 0.14656309541257512, 0.10471204188481675, 0.42857142857142855, 0.00846344605088234, 0.012550643645226955, 0.0, 0.0, 0.004527776502072811, 0.0, 0.001294999849051237, 0.019391579553484917, 0.02999694086611271, 0.0026073455810546875, 0.0, 0.0, 0.016546493396162987, 0.024497902020812035, 0.00018889713101089, 0.0, 0.005568447522819042, 0.0, 0.007975691929459572, 0.01434263214468956, 0.0, 6.733229383826256e-05, 0.0012099052546545863, 0.0, 0.0001209513284265995, 0.01868056133389473, 0.025530844926834106, 0.004079729784280062, 0.0, 0.0, 0.01332627609372139, 0.026645798236131668, 0.0, 0.0, 0.007684763520956039, 0.0, 0.010554256848990917, 0.007236589677631855, 0.0013368092477321625, 0.000697580398991704, 0.00213554291985929, 0.0, 0.0021772112231701612, 0.012761476449668407, 0.015171871520578861, 0.001512336079031229, 0.0, 0.0, 0.008273545652627945, 0.01777557097375393, 0.006600575987249613, 0.0, 0.007174563594162464, 0.0, 0.004660750739276409, 0.009024208411574364, 0.0, 0.0014235835988074541, 0.0, 0.0, 0.0, 0.008785379119217396, 0.010602384805679321, 0.0024691042490303516, 0.0, 0.0, 0.003091508522629738, 0.0120345214381814, 0.003123666625469923, 0.0, 0.005664713680744171, 0.0, 0.004825159907341003, 0.0034197410568594933, 0.0030767947901040316, 0.004110954236239195, 0.0, 0.0, 0.001896441332064569, 0.002400417113676667, 0.0012791997287422419, 0.0, 0.0, 0.0, 0.0021027529146522284, 0.006922871805727482, 0.004868669901043177, 0.0, 7.310241926461458e-05, 0.0]]

(1, 185)

_________EDIT2__________:

Отслеживание ошибок:

Файл ".! Qltrain.py", строка 360, в Qs, action, flag_alea = прогноз (Dense_model, [состояние], ACTION_SIZE, Epsilon) Файл ". \ Lib \ Core.py", строка 336, в прогнозе output = density_model .predict (состояние) Файл "C: \ Users \ Odeven \ AppData \ Local \ Programs \ Python \ Python37 \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training.py", строка 1096, в прогнозе x, check_steps = True, steps_name = 'steps', steps = steps) Файл "C: \ Users \ Odeven \ AppData \ Local \ Programs \ Python \ Python37 \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training.py ", строка 2382, в _standardize_user_data exception_prefix = 'input') Файл" C: \ Users \ Odeven \ AppData \ Local \ Programs \ Python \ Python37 \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training_utils.py " , строка 362, в standardize_input_data ', но получил массив с формой' + str (data_shape)) ValueError: Ошибка при проверке ввода: ожидалось, что density_input будет иметь форму (185,), но получил массив с формой (1,)

Если вы проверите первые 3 строки, вы увидите, что код, из которого исходит erorr, - это код, который я добавил в свое первое редактирование.

_______ самодостаточный пример _______

Содержание test.py:

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import random
import numpy as np

ModelDense = Sequential()

ModelDense.add(Dense(380, input_shape=(185,), activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(380, activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(380, activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(7, activation='elu', kernel_initializer='glorot_normal'))
optimizer = tf.keras.optimizers.Adam(lr=0.00025)

ModelDense.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])


ModelDense.summary()



def predict(dense_model, state, action_size, epsilon):

    alea = np.random.rand()

    print(state)
    print(np.array(state).shape)

    dense_model.summary()

    output = dense_model.predict(state)

    if (epsilon > alea):
        action = random.randint(1, action_size) - 1
        flag_alea = True

    else:
        action = np.argmax(output)
        flag_alea = False

    return output, action, flag_alea



state = []
state.append([np.random.rand()] * 185)
output, ac, flag = predict(ModelDense, state, 7, 0.0)

print(output)

Полный вывод:

Измените это:


person Xeyes    schedule 10.07.2019    source источник
comment
Я добавляю свою функцию прогнозирования в свой пост, я также добавляю строку, в которой я ее использую, и сообщение отладки, которое у меня появляется, когда я печатаю размер   -  person Dr. Snoopy    schedule 10.07.2019
comment
Вы уверены, что именно в этой части кода возникает ошибка? Возможно, у вас есть массив (185,1) где-то еще, и там есть ошибка. Мне непонятно, как связаны ошибка и включенный вами код.   -  person Xeyes    schedule 10.07.2019
comment
Я полностью уверен в выполненном коде, я добавляю трассировку ошибки, я знаю, что это действительно странно, поэтому я действительно не понимаю эту ошибку   -  person Dr. Snoopy    schedule 10.07.2019
comment
Затем создайте самодостаточный пример, воспроизводящий ошибку, и мы можем запустить его.   -  person Xeyes    schedule 10.07.2019
comment
Хорошо, я только что создал файл с именем test.py, в котором я написал только важный код, и я получаю ту же ошибку, я добавляю содержимое этого файла в третьем редактировании + трассировка ошибки   -  person Dr. Snoopy    schedule 10.07.2019
comment
Позвольте нам продолжить это обсуждение в чате.   -  person Xeyes    schedule 10.07.2019
comment
Хорошо, спасибо большое! это сработало! Я нахожу очень удивительным, что keras неправильно интерпретирует список, но хорошо, теперь я всегда конвертирую свой список ^^. Спасибо, что нашли время прочитать такой длинный пост   -  person Xeyes    schedule 10.07.2019


Ответы (2)


В это:

output = dense_model.predict(state)

Кажется, что keras сбивается с толку, если вы передаете простой список для предсказания и можете не делать то, что хотите, таким образом вы убедитесь, что state представляет собой массив с множеством чисел той формы, которую вы ожидаете.

output = dense_model.predict(np.array(state))

Я не профессионал в Keras, но я думаю, как сказал @Matias Valdenegro, и поскольку алгоритм будет выполнять умножение матриц через сеть, он ожидает, что вы предоставите массив.

person Dr. Snoopy    schedule 10.07.2019
comment
Когда я делаю то, что вы объяснили, в контейнере «не numpy array», я получил следующий erorr: ValueError: Ошибка при проверке ввода: ожидалось, что density_input будет иметь 2 измерения, но получил массив с формой (1, 1, 185) ... так что я думаю, что лучшим решением будет просто преобразовать мой список в массив :) - person Xeyes; 10.07.2019

Если у вас есть только одно состояние, для которого вы хотите выполнить прогноз, вы можете расширить тусклость ваших данных следующим образом:

Вы уверены, что ваш обучающий ввод имеет форму (1, 158)? Обе ошибки указывают на то, что фактическая форма (158, 1)

state = np.expand_dims(state, axis=0)
person Adakor    schedule 10.07.2019
comment
_________________________________________________________________ Layer (type) Output Shape Param # =========================================== ======================== плотный (плотный) (нет, 380) 70680 _________________________________________________________________ плотный_1 (плотный) (нет, 380) 144780 _________________________________________________________________ плотный_2 (плотный) (Нет, 380) 144780 _________________________________________________________________ плотный_3 (Плотный) (Нет, 7) 2667 =================================== ================================ Всего параметров: 362 907 Обучаемые параметры: 362 907 Необучаемые параметры: 0 _________________________________________________________________ [[0.11966889292971739 , 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,119668892929 71 739, 0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,1196 6889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0. 11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,1196688929297 1739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, +0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966889292971739, 0,11966 889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739, +0,11966889292971739, +0,11966889292971739, 0,11966889292971739, +0,11966889292971739]] (1, 185) _________________________________________________________________ слой (тип) Выходная форма Param # ========= ================================================== == ================================================== ============== Итого в год rams: 362 907 Обучаемые параметры: 362 907 Необучаемые параметры: 0 _________________________________________________________________ Traceback (последний вызов последним): File ". \ test.py ", строка 47, в выходных данных, ac, flag = predict (ModelDense, state, 7, 0.0) File". \ test.py ", строка 31, в predict output = density_model.predict (state) File" C: \ Users \ Odeven \ AppData \ Local \ Programs \ Python \ Python37 \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training.py ", строка 1096, в прогнозе x, check_steps = True, steps_name = ' steps ', steps = steps) Файл "C: \ Users \ Odeven \ AppData \ Local \ Programs \ Python \ Python37 \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training.py", строка 2382, в _standardize_user_data exception_prefix = 'input') Файл "C: \ Users \ Odeven \ AppData \ Local \ Programs \ Python \ Python37 \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training_utils.py", строка 362, в standardize_input_data ' но получил массив с формой '+ str (data_shape)) ValueError: Ошибка при проверке ввода: ожидалось, что плотный_ввод будет иметь форму (185,), но получил массив с формой (1,) - person Xeyes; 10.07.2019