Tôi đã tải một khuôn mặt VGG được đào tạo trước CNN và đã chạy thành công. Tôi muốn trích xuất trung bình hyper-cột từ các lớp 3 và 8. Tôi đã làm theo phần về trích xuất các cột siêu từ here. Tuy nhiên, kể từ khi chức năng get_output đã không làm việc, tôi đã phải thực hiện một vài thay đổi:Keras Tính năng trích xuất VGG
Nhập khẩu:
import matplotlib.pyplot as plt
import theano
from scipy import misc
import scipy as sp
from PIL import Image
import PIL.ImageOps
from keras.models import Sequential
from keras.layers.core import Flatten, Dense, Dropout
from keras.layers.convolutional import Convolution2D, MaxPooling2D, ZeroPadding2D
from keras.optimizers import SGD
import numpy as np
from keras import backend as K
chức năng chính:
#after necessary processing of input to get im
layers_extract = [3, 8]
hc = extract_hypercolumn(model, layers_extract, im)
ave = np.average(hc.transpose(1, 2, 0), axis=2)
print(ave.shape)
plt.imshow(ave)
plt.show()
Nhận tính năng chức năng: (Tôi theo this)
def get_features(model, layer, X_batch):
get_features = K.function([model.layers[0].input, K.learning_phase()], [model.layers[layer].output,])
features = get_features([X_batch,0])
return features
Trích xuất cột siêu:
def extract_hypercolumn(model, layer_indexes, instance):
layers = [K.function([model.layers[0].input],[model.layers[li].output])([instance])[0] for li in layer_indexes]
feature_maps = get_features(model,layers,instance)
hypercolumns = []
for convmap in feature_maps:
for fmap in convmap[0]:
upscaled = sp.misc.imresize(fmap, size=(224, 224),mode="F", interp='bilinear')
hypercolumns.append(upscaled)
return np.asarray(hypercolumns)
Tuy nhiên, khi tôi chạy mã này, tôi nhận được lỗi sau:
get_features = K.function([model.layers[0].input, K.learning_phase()], [model.layers[layer].output,])
TypeError: list indices must be integers, not list
Làm thế nào tôi có thể sửa lỗi này?
LƯU Ý:
Trong chức năng khai thác hyper-cột, khi tôi sử dụng feature_maps = get_features(model,1,instance)
hoặc bất kỳ số nguyên ở vị trí số 1, nó hoạt động tốt. Nhưng tôi muốn trích xuất trung bình từ lớp 3 đến 8.