ข้อผิดพลาดขนาดอินพุตบนเครือข่ายหนาแน่นตามลำดับด้วย Keras

นี่เป็นคำถามที่ยาวเพราะฉันพยายามอธิบายให้มากขึ้นเท่าที่จะทำได้เพราะเป็นปัญหาที่เกิดซ้ำสำหรับฉันและฉันไม่เข้าใจจริงๆ ดังนั้นขอขอบคุณที่สละเวลาอ่านฉัน

ฉันต้องการสร้างโมเดลที่มีความหนาแน่นต่อเนื่องซึ่งใช้เป็นรายการอินพุตที่มีมิติดังนี้:

[batch_size, ข้อมูล_มิติ]

ดังนั้นฉันจึงกำหนดเครือข่ายของฉันดังนี้:

ModelDense = Sequential()

ModelDense.add(Dense(380, input_shape=(None,185), activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(380, activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(380, activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(7, activation='elu', kernel_initializer='glorot_normal'))
optimizer = tf.keras.optimizers.Adam(lr=0.00025)

ModelDense.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])

แต่เมื่อฉันใช้เครือข่ายนี้โดยมีอินพุตที่มีรูปร่างดังนี้: (1, 185) ฉันได้รับข้อผิดพลาด:

เกิดข้อผิดพลาดเมื่อตรวจสอบอินพุต: คาดว่าหนาแน่น_อินพุตจะมี 3 มิติ แต่มีอาร์เรย์ที่มีรูปร่าง (185, 1)

อย่าถามฉันว่าทำไมฉันถึงบอกว่ารูปร่างเวกเตอร์ของฉันคือ (1, 185) และในข้อความแสดงข้อผิดพลาดที่เราเห็น (185, 1) เพราะเมื่อฉันตรวจสอบรูปร่างอาร์เรย์ของฉันก่อนที่จะให้มันเป็นอินพุตในเครือข่ายของฉัน รูปร่างที่แสดงคือ (1, 185)

ตกลง ดังนั้นฉันจึงตรวจสอบบางหัวข้อแล้วพบ อันนี้ ซึ่งมีคำอธิบายว่า :

เลเยอร์หนาแน่นต้องการอินพุตเป็น (batch_size, input_size) หรือ (batch_size, ทางเลือก,..., ทางเลือก, input_size)

นั่นคือสิ่งที่ฉันทำใช่ไหม? แต่ฉันก็เห็นว่า:

รูปร่างใน Keras:

...

ดังนั้น แม้ว่าคุณจะใช้ input_shape=(50,50,3) เมื่อ keras ส่งข้อความถึงคุณ หรือเมื่อคุณพิมพ์สรุปโมเดล มันก็จะแสดง (ไม่มี,50,50,3)

...

ดังนั้น เมื่อกำหนดรูปร่างอินพุต คุณจะมองข้ามขนาดแบตช์: input_shape=(50,50,3)

ตกลง ! มาลองตอนนี้ฉันกำหนดเลเยอร์อินพุตของฉันดังนี้:

ModelDense.add(Dense(380, input_shape=(185,), activation='elu', kernel_initializer='glorot_normal'))

เมื่อฉันทำ model.summary() :

_________________________________________________________________ ชั้น (ชนิด) พารามิเตอร์รูปร่างเอาท์พุต # ========================================= ======================== หนาแน่น (หนาแน่น) (ไม่มี, 380) 70680 _________________________________________________________________ หนาแน่น_1 (หนาแน่น) (ไม่มี, 380) 144780 _________________________________________________________________ หนาแน่น_2 (หนาแน่น) (ไม่มี, 380) 144780 _________________________________________________________________ หนาแน่น_3 (หนาแน่น) (ไม่มี, 7) 2667 ================================ ================================ พารามิเตอร์ทั้งหมด: 362,907 พารามิเตอร์ที่ฝึกได้: 362,907 พารามิเตอร์ที่ไม่สามารถฝึกได้: 0


โอเค ฉันคิดว่านั่นคือสิ่งที่ฉันต้องการ แต่เมื่อฉันให้อาร์เรย์ THE SAME เป็นอินพุต ตอนนี้ฉันได้รับข้อผิดพลาด:

ValueError: เกิดข้อผิดพลาดเมื่อตรวจสอบอินพุต: คาดว่าหนาแน่น_อินพุตจะมีรูปร่าง (185,) แต่มีอาร์เรย์ที่มีรูปร่าง (1,)

ฉันสับสน ฉันเข้าใจผิดอะไร

_________แก้ไข__________ :

ฟังก์ชั่นการทำนาย:

def predict(dense_model, state, action_size, epsilon):

    alea = np.random.rand()

    # DEBUG
    print(state)
    print(np.array(state).shape)

    output = dense_model.predict(state)

    if (epsilon > alea):
        action = random.randint(1, action_size) - 1
        flag_alea = True

    else:
        action = np.argmax(output)
        flag_alea = False

    return output, action, flag_alea

บรรทัดที่ฉันใช้ฟังก์ชันของฉัน:

Qs, action, flag_alea = predict(Dense_model, [state], ACTION_SIZE, Epsilon)

ผลลัพธ์ที่แน่นอนของการพิมพ์ 'DEBUG' ของฉัน:

[[0.0, 0.0, 0.0, 0.12410027302060064, 0.0, 0.0, 0.0, 0.0, 0.0, 0.18851780241253108, 0.0, 0.0, 0.2863141820958198, 0.0, 0.07328154770628756, 0.418848167539267, 0.07328154770628756, 0.2094240837696335, 0.42857142857142855, 0.0, 0.12410027302060064, 0.0, 0.0, 0.0, 0.0, 0.263306220774655, 0.14740566037735847, 0.40346984062941293, 0.675310642895732, 0.0, 0.0, 0.0, 0.0, 0.07328154770628756, 0.0, 0.4396892862377253, 0.0, 0.42857142857142855, 0.0, 0.12410027302060064, 0.08759635599159075, 0.0, 0.1401927621025243, 0.6755559204272007, 0.0, 0.0, 0.11564568886156315, 0.4051863857374392, 0.0, 0.0, 0.19087612139721322, 0.0, 0.07328154770628756, 0.6282722513089005, 0.14656309541257512, 0.10471204188481675, 0.42857142857142855, 0.0, 0.12410027302060064, 0.0, 0.0, 0.0, 0.0, 0.0974621385076755, 0.0, 0.0, 0.675310642895732, 0.0, 0.0, 0.0, 0.09543806069860661, 0.07328154770628756, 0.10471204188481675, 0.5129708339440129, 0.5233396901920598, 0.42857142857142855, 0.0, 0.0, 0.0, 0.0, 0.5528187746700128, 0.6755564266434103, 0.0, 0.0, 0.10086746015735323, 0.1350621285791464, 0.0, 0.0, 0.0, 0.0, 0.14891426591693724, 0.5166404112353377, 0.14656309541257512, 0.10471204188481675, 0.42857142857142855, 0.00846344605088234, 0.012550643645226955, 0.0, 0.0, 0.004527776502072811, 0.0, 0.001294999849051237, 0.019391579553484917, 0.02999694086611271, 0.0026073455810546875, 0.0, 0.0, 0.016546493396162987, 0.024497902020812035, 0.00018889713101089, 0.0, 0.005568447522819042, 0.0, 0.007975691929459572, 0.01434263214468956, 0.0, 6.733229383826256e-05, 0.0012099052546545863, 0.0, 0.0001209513284265995, 0.01868056133389473, 0.025530844926834106, 0.004079729784280062, 0.0, 0.0, 0.01332627609372139, 0.026645798236131668, 0.0, 0.0, 0.007684763520956039, 0.0, 0.010554256848990917, 0.007236589677631855, 0.0013368092477321625, 0.000697580398991704, 0.00213554291985929, 0.0, 0.0021772112231701612, 0.012761476449668407, 0.015171871520578861, 0.001512336079031229, 0.0, 0.0, 0.008273545652627945, 0.01777557097375393, 0.006600575987249613, 0.0, 0.007174563594162464, 0.0, 0.004660750739276409, 0.009024208411574364, 0.0, 0.0014235835988074541, 0.0, 0.0, 0.0, 0.008785379119217396, 0.010602384805679321, 0.0024691042490303516, 0.0, 0.0, 0.003091508522629738, 0.0120345214381814, 0.003123666625469923, 0.0, 0.005664713680744171, 0.0, 0.004825159907341003, 0.0034197410568594933, 0.0030767947901040316, 0.004110954236239195, 0.0, 0.0, 0.001896441332064569, 0.002400417113676667, 0.0012791997287422419, 0.0, 0.0, 0.0, 0.0021027529146522284, 0.006922871805727482, 0.004868669901043177, 0.0, 7.310241926461458e-05, 0.0]]

(1, 185)

_________แก้ไข2__________ :

การติดตามข้อผิดพลาด:

ไฟล์ ".!Qltrain.py", บรรทัด 360, ใน Qs, การกระทำ, flag_alea = ทำนาย (Dense_model, [state], ACTION_SIZE, Epsilon) ไฟล์ ".\Lib\Core.py", บรรทัด 336, ในการทำนายเอาต์พุต =หนาแน่น_model .predict(state) ไฟล์ "C:\Users\Odeven\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training.py", บรรทัด 1,096, ในการทำนาย x, check_steps=True, step_name='steps', step=steps) ไฟล์ "C:\Users\Odeven\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training.py ", บรรทัด 2382 ใน _standardize_user_dataข้อยกเว้น_prefix='input') ไฟล์ "C:\Users\Odeven\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training_utils.py" บรรทัด 362 ใน standardize_input_data ' แต่มีอาร์เรย์ที่มีรูปร่าง ' + str (data_shape)) ValueError: เกิดข้อผิดพลาดเมื่อตรวจสอบอินพุต: คาดว่าหนาแน่น_input จะมีรูปร่าง (185,) แต่มีอาร์เรย์ที่มีรูปร่าง (1,)

หากคุณตรวจสอบ 3 บรรทัดแรก คุณจะเห็นว่ารหัสที่ข้อผิดพลาดเกิดขึ้นคือรหัสที่ฉันเพิ่มในการแก้ไขครั้งแรก

_______ตัวอย่างที่ชัดเจน_______

เนื้อหาของ test.py:

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import random
import numpy as np

ModelDense = Sequential()

ModelDense.add(Dense(380, input_shape=(185,), activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(380, activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(380, activation='elu', kernel_initializer='glorot_normal'))
ModelDense.add(Dense(7, activation='elu', kernel_initializer='glorot_normal'))
optimizer = tf.keras.optimizers.Adam(lr=0.00025)

ModelDense.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])


ModelDense.summary()



def predict(dense_model, state, action_size, epsilon):

    alea = np.random.rand()

    print(state)
    print(np.array(state).shape)

    dense_model.summary()

    output = dense_model.predict(state)

    if (epsilon > alea):
        action = random.randint(1, action_size) - 1
        flag_alea = True

    else:
        action = np.argmax(output)
        flag_alea = False

    return output, action, flag_alea



state = []
state.append([np.random.rand()] * 185)
output, ac, flag = predict(ModelDense, state, 7, 0.0)

print(output)

เอาท์พุทที่สมบูรณ์:

เปลี่ยนสิ่งนี้:


person Xeyes    schedule 10.07.2019    source แหล่งที่มา
comment
ฉันเพิ่มฟังก์ชันการคาดเดาลงในโพสต์ของฉัน ฉันยังเพิ่มบรรทัดที่ฉันใช้และข้อความแก้ไขข้อบกพร่องที่ฉันมีเมื่อฉันพิมพ์มิติข้อมูล   -  person Dr. Snoopy    schedule 10.07.2019
comment
คุณแน่ใจหรือว่านี่เป็นส่วนหนึ่งของโค้ดที่ทำให้เกิดข้อผิดพลาด บางทีคุณอาจมีอาร์เรย์ (185,1) อยู่ที่อื่นและมีข้อผิดพลาดอยู่ที่นั่น สำหรับฉันยังไม่ชัดเจนว่าข้อผิดพลาดและรหัสที่คุณรวมไว้เกี่ยวข้องกันอย่างไร   -  person Xeyes    schedule 10.07.2019
comment
ฉันแน่ใจอย่างสมบูรณ์เกี่ยวกับโค้ดที่ดำเนินการ ฉันเพิ่มการสืบค้นย้อนกลับของข้อผิดพลาด ฉันรู้ว่ามันแปลกจริงๆ นั่นคือสาเหตุที่ฉันไม่เข้าใจข้อผิดพลาดนี้จริงๆ   -  person Dr. Snoopy    schedule 10.07.2019
comment
จากนั้นสร้างตัวอย่างที่สมบูรณ์ในตัวเองซึ่งทำให้เกิดข้อผิดพลาดอีกครั้งและเราสามารถทำงานได้   -  person Xeyes    schedule 10.07.2019
comment
ตกลง ฉันเพิ่งสร้างไฟล์ชื่อ test.py โดยที่ฉันเขียนเฉพาะโค้ดสำคัญเท่านั้น และฉันได้รับข้อผิดพลาดเดียวกัน ฉันเพิ่มเนื้อหาของไฟล์นี้ในการแก้ไขครั้งที่สาม + การติดตามข้อผิดพลาด   -  person Dr. Snoopy    schedule 10.07.2019
comment
ให้เราสนทนาต่อในการแชท   -  person Xeyes    schedule 10.07.2019
comment
โอเค ขอบคุณมาก ! นั่นใช้ได้ดี ! ฉันพบว่าน่าประหลาดใจมากที่ keras ตีความรายการไม่ถูกต้อง แต่โอเค ฉันจะแปลงรายการของฉันทุกครั้ง ^^ ขอบคุณที่สละเวลาอ่านโพสต์ยาว ๆ เช่นนี้   -  person Xeyes    schedule 10.07.2019


คำตอบ (2)


เป็นสิ่งนี้:

output = dense_model.predict(state)

ดูเหมือนว่า keras จะสับสนหากคุณผ่านรายการธรรมดาเพื่อคาดเดาและอาจไม่ทำสิ่งที่คุณต้องการ ด้วยวิธีนี้คุณจึงมั่นใจได้ว่า state เป็นอาร์เรย์ที่มีตัวเลขของรูปร่างที่คุณคาดหวัง

output = dense_model.predict(np.array(state))

ฉันไม่ใช่มืออาชีพของ Keras แต่ฉันคิดว่าอย่างที่ @Matias Valdenegro พูดและเนื่องจากอัลกอริธึมกำลังจะทำการคูณเมทริกซ์ผ่านเครือข่าย คุณจึงคาดหวังให้คุณให้อาร์เรย์

person Dr. Snoopy    schedule 10.07.2019
comment
เมื่อฉันทำสิ่งที่คุณอธิบายบนคอนเทนเนอร์ 'ไม่ใช่อาร์เรย์' ฉันได้รับข้อผิดพลาดต่อไปนี้: ValueError: ข้อผิดพลาดเมื่อตรวจสอบอินพุต: คาดว่าหนาแน่น_อินพุตจะมี 2 มิติ แต่มีอาร์เรย์ที่มีรูปร่าง (1, 1, 185) ... ดังนั้น ฉันคิดว่าทางออกที่ดีที่สุดคือแปลงรายการของฉันเป็นอาร์เรย์ :) - person Xeyes; 10.07.2019

หากคุณมีสถานะเดียวที่คุณต้องการคาดการณ์ คุณสามารถขยายขอบเขตข้อมูลของคุณได้ดังนี้:

คุณแน่ใจหรือไม่ว่าข้อมูลการฝึกของคุณมีรูปร่าง (1, 158)? ข้อผิดพลาดทั้งสองชี้ว่ารูปร่างจริงคือ (158, 1)

state = np.expand_dims(state, axis=0)
person Adakor    schedule 10.07.2019
comment
_________________________________________________________________ ชั้น (ชนิด) พารามิเตอร์รูปร่างเอาท์พุต # ========================================= ======================== หนาแน่น (หนาแน่น) (ไม่มี, 380) 70680 _________________________________________________________________ หนาแน่น_1 (หนาแน่น) (ไม่มี, 380) 144780 _________________________________________________________________ หนาแน่น_2 (หนาแน่น) (ไม่มี, 380) 144780 _________________________________________________________________ หนาแน่น_3 (หนาแน่น) (ไม่มี, 7) 2667 ================================ ================================ พารามิเตอร์ทั้งหมด: 362,907 พารามิเตอร์ที่ฝึกได้: 362,907 พารามิเตอร์ที่ไม่สามารถฝึกได้: 0 __________________________________________________________________ [[0.11966889292971739 , 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.119668892929 71739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.119668892 92971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.119668 89292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.119 66889292971739, 0.11966889292971739 , 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.119668892929 71739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.119668892 92971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.119668 89292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.119 66889292971739, 0.11966889292971739 , 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.119668892929 71739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.119668892 92971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.119668 89292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.119 66889292971739, 0.11966889292971739 , 0.11966889292971739, 0. 11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.1196688929297173 9, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.1196688929297 1739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.1196688929 2971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.1196688 9292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971 739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292 971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889 292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966 889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971 739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292 971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889 292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966 889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971 739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292 971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889 292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966 889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971739, 0.11966889292971 739, 0.11966889292971739, 0.11966889292971739]] (1, 185) __________________________________________________________________ เลเยอร์ (ประเภท) พารามิเตอร์รูปร่างเอาต์พุต # =================== ============================================== หนาแน่น (หนาแน่น) (ไม่มี, 380) 70680 _________________________________________________________________ หนาแน่น_1 (หนาแน่น) (ไม่มี, 380) 144780 _________________________________________________________________ หนาแน่น_2 (หนาแน่น) (ไม่มี, 380) 144780 _________________________________________________________________ หนาแน่น_3 (หนาแน่น) (ไม่มี, 7) 2667 =========== ================================================== ==== พารามิเตอร์ทั้งหมด: 362,907 พารามิเตอร์ที่ฝึกได้: 362,907 พารามิเตอร์ที่ไม่สามารถฝึกได้: 0 __________________________________________________________________ Traceback (การโทรล่าสุดครั้งล่าสุด): ไฟล์ " \test.py", บรรทัดที่ 47, ในเอาต์พุต, ac, flag = ทำนาย (ModelDense, state, 7, 0.0) ไฟล์ ".\test.py", บรรทัด 31, ในเอาต์พุตคาดการณ์ =หนาแน่น_model.predict(state) ไฟล์ " C:\Users\Odeven\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training.py", บรรทัด 1,096, ในการทำนาย x, check_steps=True, step_name=' ขั้นตอน ', ขั้นตอน = ขั้นตอน) ไฟล์ "C:\Users\Odeven\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training.py", บรรทัด 2382 ใน _standardize_user_data ข้อยกเว้น_prefix='input') ไฟล์ "C:\Users\Odeven\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training_utils.py", บรรทัด 362, ใน standardize_input_data ' แต่มีอาร์เรย์ที่มีรูปร่าง ' + str(data_shape)) ValueError: เกิดข้อผิดพลาดเมื่อตรวจสอบอินพุต: คาดว่าหนาแน่น_input จะมีรูปร่าง (185,) แต่มีอาร์เรย์ที่มีรูปร่าง (1,) - person Xeyes; 10.07.2019