2017-08-30 36 views
8

Đối với đào tạo một mô hình LSTM trong Tensorflow, tôi đã có cấu trúc dữ liệu của tôi sang một định dạng tf.train.SequenceExample và lưu trữ nó vào một TFRecord nộp. Bây giờ tôi muốn sử dụng API Số liệu mới để tạo các lô độn để đào tạo. Trong the documentation có ví dụ về việc sử dụng padded_batch, nhưng đối với dữ liệu của tôi, tôi không thể tìm ra giá trị của padded_shapes.Làm cách nào để tạo các lô đệm trong Tensorflow cho tf.train.SequenceExample dữ liệu bằng cách sử dụng API DataSet?

Đối với đọc file TFrecord vào lô Tôi đã viết mã Python sau:

import math 
import tensorflow as tf 
import numpy as np 
import struct 
import sys 
import array 

if(len(sys.argv) != 2): 
    print "Usage: createbatches.py [RFRecord file]" 
    sys.exit(0) 


vectorSize = 40 
inFile = sys.argv[1] 

def parse_function_dataset(example_proto): 
    sequence_features = { 
     'inputs': tf.FixedLenSequenceFeature(shape=[vectorSize], 
              dtype=tf.float32), 
     'labels': tf.FixedLenSequenceFeature(shape=[], 
              dtype=tf.int64)} 

    _, sequence = tf.parse_single_sequence_example(example_proto, sequence_features=sequence_features) 

    length = tf.shape(sequence['inputs'])[0] 
    return sequence['inputs'], sequence['labels'] 

sess = tf.InteractiveSession() 

filenames = tf.placeholder(tf.string, shape=[None]) 
dataset = tf.contrib.data.TFRecordDataset(filenames) 
dataset = dataset.map(parse_function_dataset) 
# dataset = dataset.batch(1) 
dataset = dataset.padded_batch(4, padded_shapes=[None]) 
iterator = dataset.make_initializable_iterator() 

batch = iterator.get_next() 

# Initialize `iterator` with training data. 
training_filenames = [inFile] 
sess.run(iterator.initializer, feed_dict={filenames: training_filenames}) 

print(sess.run(batch)) 

Mã này hoạt động tốt nếu tôi sử dụng dataset = dataset.batch(1) (không có đệm cần thiết trong trường hợp đó), nhưng khi tôi sử dụng padded_batch biến thể, tôi nhận được lỗi sau:

TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: .

bạn có thể giúp tôi tìm ra những gì tôi phải vượt qua cho padded_shapes tham số?

(Tôi biết có rất nhiều ví dụ mã sử dụng luồng và hàng đợi cho điều này, nhưng tôi muốn sử dụng các API DataSet mới cho dự án này)

+0

Cảm ơn Marijn! Câu hỏi của bạn đã giúp tôi rất nhiều! –

Trả lời

6

Bạn cần phải vượt qua một tuple của hình dạng. Trong trường hợp của bạn, bạn phải vượt qua

dataset = dataset.padded_batch(4, padded_shapes=([vectorSize],[None])) 

hoặc thử

dataset = dataset.padded_batch(4, padded_shapes=([None],[None])) 

Kiểm tra code này để biết thêm chi tiết. Tôi đã phải gỡ lỗi phương pháp này để tìm ra lý do tại sao nó không làm việc cho tôi.

+0

Cảm ơn! Điều đó làm cho rất nhiều ý nghĩa. Sau đây làm việc cho ví dụ của tôi: 'padded_shapes = ([None, vectorSize], [None])'. Các tensor đầu tiên là một danh sách các vectơ với kích thước vectorSize và thứ hai là một danh sách với các nhãn số nguyên. –

+0

Giống như bổ sung, 'padded_shapes' nhạy cảm với loại cấu trúc lồng nhau (nếu tập dữ liệu trả về một bộ tuple, padded_shapes cũng phải là một bộ tuple và không phải là danh sách) – Conchylicultor

0

Nếu đối tượng Dataset hiện tại của bạn chứa một bộ tuple, bạn cũng có thể chỉ định hình dạng của từng phần tử đệm.

Ví dụ: tôi có bộ dữ liệu (same_sized_images, Labels) và mỗi nhãn có độ dài khác nhau nhưng có cùng thứ hạng.

def process_label(resized_img, label): 
    # Perfrom some tensor transformations 
    # ...... 

    return resized_img, label 

dataset = dataset.map(process_label) 
dataset = dataset.padded_batch(batch_size, 
           padded_shapes=([None, None, 3], 
               [None, None])) # my label has rank 2 
Các vấn đề liên quan