2016-07-30 21 views
7

Tôi đã tải một khuôn mặt VGG được đào tạo trước CNN và đã chạy thành công. Tôi muốn trích xuất trung bình hyper-cột từ các lớp 3 và 8. Tôi đã làm theo phần về trích xuất các cột siêu từ here. Tuy nhiên, kể từ khi chức năng get_output đã không làm việc, tôi đã phải thực hiện một vài thay đổi:Keras Tính năng trích xuất VGG

Nhập khẩu:

import matplotlib.pyplot as plt 
import theano 
from scipy import misc 
import scipy as sp 
from PIL import Image 
import PIL.ImageOps 
from keras.models import Sequential 
from keras.layers.core import Flatten, Dense, Dropout 
from keras.layers.convolutional import Convolution2D, MaxPooling2D, ZeroPadding2D 
from keras.optimizers import SGD 
import numpy as np 
from keras import backend as K 

chức năng chính:

#after necessary processing of input to get im 
layers_extract = [3, 8] 
hc = extract_hypercolumn(model, layers_extract, im) 
ave = np.average(hc.transpose(1, 2, 0), axis=2) 
print(ave.shape) 
plt.imshow(ave) 
plt.show() 

Nhận tính năng chức năng: (Tôi theo this)

def get_features(model, layer, X_batch): 
    get_features = K.function([model.layers[0].input, K.learning_phase()], [model.layers[layer].output,]) 
    features = get_features([X_batch,0]) 
    return features 

Trích xuất cột siêu:

def extract_hypercolumn(model, layer_indexes, instance): 
    layers = [K.function([model.layers[0].input],[model.layers[li].output])([instance])[0] for li in layer_indexes] 
    feature_maps = get_features(model,layers,instance) 
    hypercolumns = [] 
    for convmap in feature_maps: 
     for fmap in convmap[0]: 
      upscaled = sp.misc.imresize(fmap, size=(224, 224),mode="F", interp='bilinear') 
      hypercolumns.append(upscaled) 
    return np.asarray(hypercolumns) 

Tuy nhiên, khi tôi chạy mã này, tôi nhận được lỗi sau:

get_features = K.function([model.layers[0].input, K.learning_phase()], [model.layers[layer].output,]) 
TypeError: list indices must be integers, not list 

Làm thế nào tôi có thể sửa lỗi này?

LƯU Ý:

Trong chức năng khai thác hyper-cột, khi tôi sử dụng feature_maps = get_features(model,1,instance) hoặc bất kỳ số nguyên ở vị trí số 1, nó hoạt động tốt. Nhưng tôi muốn trích xuất trung bình từ lớp 3 đến 8.

Trả lời

1

Nó bối rối tôi rất nhiều:

  1. Sau layers = [K.function([model.layers[0].input],[model.layers[li].output])([instance])[0] for li in layer_indexes], lớp là danh sách các tính năng trích xuất.
  2. Và sau đó bạn gửi danh sách đó vào feature_maps = get_features(model,layers,instance).
  3. Trong def get_features(model, layer, X_batch):, thông số thứ hai, cụ thể là layer, được sử dụng để lập chỉ mục trong model.layers[layer].output.

gì bạn muốn là:

  1. feature_maps = get_features(model,layer_indexes,instance): đi qua các chỉ số lớp chứ không phải là tính năng trích xuất.
  2. get_features = K.function([model.layers[0].input, K.learning_phase()], [model.layers [l] .kết quả cho l trong lớp]): không thể sử dụng danh sách để lập chỉ mục danh sách.

Tuy nhiên, chức năng trừu tượng tính năng của bạn được viết khủng khiếp. Tôi đề nghị bạn viết lại mọi thứ thay vì trộn mã.

0

Tôi viết lại hàm của bạn cho một hình ảnh đầu vào kênh đơn (W x H x 1). Có lẽ nó sẽ hữu ích.

def extract_hypercolumn(model, layer_indexes, instance): 
    test_image = instance 
    outputs = [layer.output for layer in model.layers]   # all layer outputs 
    comp_graph = [K.function([model.input]+ [K.learning_phase()], [output]) for output in outputs] # evaluation functions 

    feature_maps = [] 
    for layerIdx in layer_indexes: 
     feature_maps.append(layer_outputs_list[layerIdx][0][0]) 


    hypercolumns = [] 
    for idx, convmap in enumerate(feature_maps): 
     #  vv = np.asarray(convmap) 
     #  print(vv.shape) 
     vv = np.asarray(convmap) 
     print('shape of feature map at layer ', layer_indexes[idx], ' is: ', vv.shape) 

     for i in range(vv.shape[-1]): 
      fmap = vv[:,:,i] 
      upscaled = sp.misc.imresize(fmap, size=(img_width, img_height), 
            mode="F", interp='bilinear') 
      hypercolumns.append(upscaled) 

    # hypc = np.asarray(hypercolumns) 
    # print('shape of hypercolumns ', hypc.shape) 

    return np.asarray(hypercolumns) 
Các vấn đề liên quan