2015-11-19 15 views
8

Đây là mã mà tôi đang cố gắng để Run-Làm cách nào để thay đổi loại dtype trong TensorFlow cho tệp csv?

import tensorflow as tf 
import numpy as np 
import input_data 

filename_queue = tf.train.string_input_producer(["cs-training.csv"]) 

reader = tf.TextLineReader() 
key, value = reader.read(filename_queue) 

record_defaults = [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1]] 
col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11 = tf.decode_csv(
    value, record_defaults=record_defaults) 
features = tf.concat(0, [col2, col3, col4, col5, col6, col7, col8, col9, col10, col11]) 

with tf.Session() as sess: 
    # Start populating the filename queue. 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    for i in range(1200): 
    # Retrieve a single instance: 
    print i 
    example, label = sess.run([features, col1]) 
    try: 
     print example, label 
    except: 
     pass 

    coord.request_stop() 
    coord.join(threads) 

Mã này trả lại lỗi dưới đây.

--------------------------------------------------------------------------- 
InvalidArgumentError      Traceback (most recent call last) 
<ipython-input-23-e42fe2609a15> in <module>() 
     7  # Retrieve a single instance: 
     8  print i 
----> 9  example, label = sess.run([features, col1]) 
    10  try: 
    11   print example, label 

/root/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict) 
    343 
    344  # Run request and get response. 
--> 345  results = self._do_run(target_list, unique_fetch_targets, feed_dict_string) 
    346 
    347  # User may have fetched the same tensor multiple times, but we 

/root/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, target_list, fetch_list, feed_dict) 
    417   # pylint: disable=protected-access 
    418   raise errors._make_specific_exception(node_def, op, e.error_message, 
--> 419            e.code) 
    420   # pylint: enable=protected-access 
    421  raise e_type, e_value, e_traceback 

InvalidArgumentError: Field 1 in record 0 is not a valid int32: 0.766126609 

Nó có rất nhiều thông tin mà tôi nghĩ là không liên quan đến vấn đề. Rõ ràng vấn đề là rất nhiều dữ liệu mà tôi đang cho chương trình không phải là của intty dtype. Nó chủ yếu là số phao. Tôi đã thử một vài điều để thay đổi kiểu dtype như thiết lập rõ ràng đối số dtype=float trong tf.decode_csv cũng như tf.concat. Không làm việc. Đó là một đối số không hợp lệ. Để đầu nó tất cả ra, tôi không biết nếu mã này sẽ thực sự làm cho một dự đoán trên dữ liệu. Tôi muốn nó dự đoán liệu col1 sẽ là 1 hay 0 và tôi không thấy bất cứ điều gì trong đoạn mã gợi ý rằng nó sẽ thực sự đưa ra dự đoán đó. Có lẽ tôi sẽ lưu câu hỏi đó cho một chủ đề khác. Bất kỳ trợ giúp nào cũng được đánh giá rất cao!

Trả lời

1

Câu trả lời cho việc thay đổi dtype là chỉ cần thay đổi giá trị mặc định như Somali

record_defaults = [[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]] 

Sau khi bạn làm điều đó, nếu bạn in ra col1, bạn sẽ nhận được thông báo này.

Tensor("DecodeCSV_43:0", shape=TensorShape([]), dtype=float32) 

Nhưng có một lỗi mà bạn sẽ chạy vào, which has been answered here. Để tóm tắt lại các câu trả lời, thực hiện giải pháp là thay đổi tf.concat-tf.pack như vậy.

features = tf.pack([col2, col3, col4, col5, col6, col7, col8, col9, col10, col11]) 
13

Giao diện tf.decode_csv() hơi phức tạp. Các dtype của mỗi cột được xác định bởi các yếu tố tương ứng của đối số record_defaults. Giá trị cho record_defaults trong mã của bạn được hiểu là mỗi cột có tf.int32 làm loại của nó, dẫn đến lỗi khi gặp phải dữ liệu dấu phẩy động.

Hãy nói rằng bạn có những dữ liệu CSV sau, chứa ba cột số nguyên, tiếp theo là một điểm cột nổi:

4, 8, 9, 4.5 
2, 5, 1, 3.7 
2, 2, 2, 0.1 

Giả sử tất cả các cột được cần, bạn sẽ xây dựng record_defaults như sau:

value = ... 

record_defaults = [tf.constant([], dtype=tf.int32), # Column 0 
        tf.constant([], dtype=tf.int32), # Column 1 
        tf.constant([], dtype=tf.int32), # Column 2 
        tf.constant([], dtype=tf.float32)] # Column 3 

col0, col1, col2, col3 = tf.decode_csv(value, record_defaults=record_defauts) 

assert col0.dtype == tf.int32 
assert col1.dtype == tf.int32 
assert col2.dtype == tf.int32 
assert col3.dtype == tf.float32 

Giá trị rỗng trong record_defaults biểu thị rằng giá trị là bắt buộc. Ngoài ra, nếu (ví dụ) cột 2 được phép có giá trị mất tích, bạn sẽ xác định record_defaults như sau:

record_defaults = [tf.constant([], dtype=tf.int32),  # Column 0 
        tf.constant([], dtype=tf.int32),  # Column 1 
        tf.constant([0], dtype=tf.int32), # Column 2 
        tf.constant([], dtype=tf.float32)] # Column 3 

Phần thứ hai của mối quan tâm câu hỏi của bạn làm thế nào để xây dựng và đào tạo một mô hình mà dự đoán giá trị của một trong những các cột từ dữ liệu đầu vào. Hiện tại, chương trình không: nó chỉ đơn giản là nối các cột vào một tensor đơn, gọi là features. Bạn sẽ cần xác định và đào tạo một mô hình, giải thích dữ liệu đó. Một trong những cách tiếp cận đơn giản nhất như vậy là hồi quy tuyến tính và bạn có thể tìm thấy hướng dẫn này trên linear regression in TensorFlow thích ứng với vấn đề của bạn.

Các vấn đề liên quan