2016-03-21 18 views
8

Tôi đang cố gắng để chạy SimpleRNN này:Sai số của kích thước trên model.fit

model.add(SimpleRNN(init='uniform',output_dim=1,input_dim=len(pred_frame.columns))) 
model.compile(loss="mse", optimizer="sgd") 
model.fit(X=predictor_train, y=target_train, batch_size=len(pred_frame.index),show_accuracy=True) 

Lỗi này là trên model.fit, như bạn có thể xem dưới đây:

File "/Users/file.py", line 1496, in Pred 
model.fit(X=predictor_train, y=target_train, batch_size=len(pred_frame.index),show_accuracy=True) 
File "/Library/Python/2.7/site-packages/keras/models.py", line 581, in fit 
shuffle=shuffle, metrics=metrics) 
File "/Library/Python/2.7/site-packages/keras/models.py", line 239, in _fit 
outs = f(ins_batch) 
File "/Library/Python/2.7/site-packages/keras/backend/theano_backend.py", line 365, in __call__ 
return self.function(*inputs) 
File "/Library/Python/2.7/site-packages/theano/compile/function_module.py", line 513, in __call__ 
allow_downcast=s.allow_downcast) 
File "/Library/Python/2.7/site-packages/theano/tensor/type.py", line 169, in filter 
data.shape)) 
TypeError: ('Bad input argument to theano function with name "/Library/Python/2.7/site-packages/keras/backend/theano_backend.py:362" at index 0(0-based)', 'Wrong number of dimensions: expected 3, got 2 with shape (88, 88).') 

Các lỗi nói với tôi rằng nó có sai số kích thước, nó phải là 3 và nó chỉ có 2. Kích thước mà nó đang đề cập đến là gì?

Trả lời

8

Bạn đang cố gắng chạy RNN. Điều này có nghĩa là bạn muốn bao gồm các bước thời gian trước trong tính toán của bạn. Để làm như vậy, bạn phải xử lý trước dữ liệu của bạn trước khi đưa nó vào lớp SimpleRNN.

Để đơn giản, hãy giả sử rằng thay vì 88 mẫu có 88 tính năng, mỗi mẫu có 8 mẫu với 4 tính năng. Bây giờ, khi sử dụng RNN, bạn sẽ phải quyết định mức tối đa cho backpropagation (tức là số bước thời gian trước đó được bao gồm trong phép tính). Trong trường hợp này, bạn có thể chọn bao gồm tối đa 2 bước thời gian trước đó. Do đó, để tính toán trọng số của RNN, bạn sẽ phải cung cấp mỗi lần bước đầu vào của bước thời gian hiện tại (với 4 tính năng) và đầu vào của 2 bước thời gian trước đó (với 4 tính năng). Giống như trong hình dung này:

sequence sample0 sample1 sample2 sample3 sample4 sample5 sample6 sample7  
    0  |-----------------------| 
    1     |-----------------------| 
    2       |-----------------------| 
    3         |-----------------------| 
    4            |----------------------| 
    5              |----------------------| 

Vì vậy, thay vì đưa ra một (nb_samples, nb_features) ma trận như một đầu vào cho SimpleRNN, bạn sẽ phải cung cấp cho nó một (nb_sequences, nb_timesteps, nb_features) đầu vào hình. Trong ví dụ này, nó có nghĩa là thay vì đưa ra một đầu vào (8x4) bạn cho nó một đầu vào (5x3x4).

Các keras Embedding lớp có thể làm công việc này nhưng trong trường hợp này, bạn cũng có thể viết một mã ngắn cho nó:

input = np.random.rand(8,4) 
nb_timesteps = 3 # 2 (previous) + 1 (current) 
nb_sequences = input.shape[0] - nb_timesteps #8-3=5 

input_3D = np.array([input[i:i+nb_timesteps] for i in range(nb_sequences)]) 
+0

Cảm ơn bạn đã giải thích, tôi có một vấn đề tương tự. Tại sao các bước tối đa cho backpropagation trong trường hợp bạn đang bình luận chỉ là 2? Và tại sao số lượng chuỗi là 5? Bằng cách này, với trình tự bạn có nghĩa là một kỷ nguyên trong đào tạo? – David

+1

Tôi đã chọn ngẫu nhiên số 2 là số lượng các bước trước đó cho backpropagation cho ví dụ này. Kết hợp với số lượng (cũng được tự do lựa chọn) của 8 mẫu, điều này dẫn đến tổng cộng 5 chuỗi. Trình tự 1 bao gồm mẫu 1, 2 và 3, trình tự 2 bao gồm 2, 3 và 4, v.v. Xem trực quan trong câu trả lời của tôi để biết chi tiết. Một kỷ nguyên được thực hiện khi mạng đã được đào tạo với mỗi chuỗi một lần. Sau đó, bạn bắt đầu với chuỗi đầu tiên một lần nữa – Lorrit

4

Các lỗi có lẽ là bởi vì kích thước đầu vào của bạn không có định dạng của:

(nb_samples, timesteps, input_dim) 

Nó được mong đợi 3 chiều, và bạn đang cung cấp chỉ có 2 trong số họ (88,88).

+0

nổi bật @Tarantula! Độ mờ đầu vào là số cột của dataframe dự đoán của tôi, phải không? Timesteps và nb_samples là gì? Hàng và kích thước khung hình? – abutremutante

Các vấn đề liên quan