2016-04-02 34 views
6

Tôi đã cố gắng tiêm các vectơ word2vec được pretrained vào mô hình seq2seq tensorflow hiện có.Tiêm vectơ word2vec đã được đào tạo trước vào TensorFlow seq2seq

Sau this answer, tôi đã tạo mã sau đây. Nhưng nó dường như không cải thiện hiệu suất như nó cần, mặc dù các giá trị trong biến được cập nhật.

Trong sự hiểu biết của tôi, lỗi có thể là do thực tế là EmbeddingWrapper hoặc embedding_attention_decoder tạo tệp nhúng độc lập với thứ tự từ vựng?

Cách tốt nhất để tải véc tơ giả vờ vào mô hình tensorflow là gì?

SOURCE_EMBEDDING_KEY = "embedding_attention_seq2seq/RNN/EmbeddingWrapper/embedding" 
TARGET_EMBEDDING_KEY = "embedding_attention_seq2seq/embedding_attention_decoder/embedding" 


def inject_pretrained_word2vec(session, word2vec_path, input_size, dict_dir, source_vocab_size, target_vocab_size): 
    word2vec_model = word2vec.load(word2vec_path, encoding="latin-1") 
    print("w2v model created!") 
    session.run(tf.initialize_all_variables()) 

    assign_w2v_pretrained_vectors(session, word2vec_model, SOURCE_EMBEDDING_KEY, source_vocab_path, source_vocab_size) 
    assign_w2v_pretrained_vectors(session, word2vec_model, TARGET_EMBEDDING_KEY, target_vocab_path, target_vocab_size) 


def assign_w2v_pretrained_vectors(session, word2vec_model, embedding_key, vocab_path, vocab_size): 
    vectors_variable = [v for v in tf.trainable_variables() if embedding_key in v.name] 
    if len(vectors_variable) != 1: 
     print("Word vector variable not found or too many. key: " + embedding_key) 
     print("Existing embedding trainable variables:") 
     print([v.name for v in tf.trainable_variables() if "embedding" in v.name]) 
     sys.exit(1) 

    vectors_variable = vectors_variable[0] 
    vectors = vectors_variable.eval() 

    with gfile.GFile(vocab_path, mode="r") as vocab_file: 
     counter = 0 
     while counter < vocab_size: 
      vocab_w = vocab_file.readline().replace("\n", "") 
      # for each word in vocabulary check if w2v vector exist and inject. 
      # otherwise dont change the value. 
      if word2vec_model.__contains__(vocab_w): 
       w2w_word_vector = word2vec_model.get_vector(vocab_w) 
       vectors[counter] = w2w_word_vector 
      counter += 1 

    session.run([vectors_variable.initializer], 
      {vectors_variable.initializer.inputs[1]: vectors}) 

Trả lời

5

Tôi không quen thuộc với ví dụ seq2seq, nhưng nói chung bạn có thể sử dụng đoạn mã sau đây để tiêm embeddings của bạn:

Nơi bạn xây dựng bạn tạo biểu đồ:

with tf.device("/cpu:0"): 
    embedding = tf.get_variable("embedding", [vocabulary_size, embedding_size])  
    inputs = tf.nn.embedding_lookup(embedding, input_data) 

Khi bạn thực thi (sau khi tạo biểu đồ của bạn và trước khi nêu rõ khóa đào tạo), chỉ cần chỉ định các nhúng đã nhúng của bạn cho biến nhúng:

session.run(tf.assign(embedding, embeddings_that_you_want_to_use)) 

Ý tưởng là embedding_lookup sẽ thay thế các giá trị input_data bằng những giá trị hiện tại trong biến số embedding.

Các vấn đề liên quan