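On the 1st of this month, Shanghai began enforcing what has been called its "strictest ever" waste-sorting regulations, with fines of up to 200 yuan for putting waste in the wrong bin. Another 46 cities across China will follow in turn, and jokes about people being driven mad by waste sorting keep appearing on social media.
From an AI perspective, waste sorting is simply an application of image classification, one of the core tasks in image processing. In the ImageNet image classification competitions held since 2012, SENet stands out: it won the 2017 challenge with a top-5 test error of 2.25%, which makes it arguably the strongest image classifier to date.
I also just visited the website of Momenta, the company behind SENet. Their latest work has already moved on to 3D object recognition and annotation. [Figure: sample results from the Momenta website]
So it is fair to say that SENet is more than capable of handling images of waste.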
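Introduction to SENet
SENet was proposed jointly by Momenta and the University of Oxford. It is built on two operations, squeeze and excitation: each block uses the "squeeze" operation to embed information from the global receptive field, and the "excitation" operation to selectively strengthen informative responses. Past ImageNet winners mostly improved accuracy by making models larger and adding more connections; against this "brute force works wonders" trend, SENet is a breath of fresh air. The paper is available at: http://openaccess.thecvf.com/content_cvpr_2018/papers/Hu_Squeeze-and-Excitation_Networks_CVPR_2018_paper.pdf
The principle works as follows: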
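Squeeze: apply global average pooling to the C×H×W feature map to obtain a 1×1×C feature map, which can be understood as having a global receptive field. To paraphrase the authors: each two-dimensional feature channel is turned into a single real number that, to some extent, has a global receptive field, and the output dimension matches the number of input feature channels. It characterizes the global distribution of responses over the feature channels, and lets layers close to the input obtain a global receptive field as well.
Excitation: a small fully connected network applies a nonlinear transformation to the squeeze output. The mechanism resembles the gates in recurrent neural networks. The parameters w generate one weight per feature channel, and w is learned to explicitly model the correlations between feature channels.
Feature recalibration: the excitation output is used as a set of weights and multiplied onto the input features. These weights can be read as the importance of each feature channel; multiplying them channel by channel onto the earlier features completes the recalibration of the original features.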
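The three steps above can be sketched in a few lines of NumPy. This is a minimal illustration rather than the authors' implementation: the function name se_block, the random weights, and the tiny tensor sizes are assumptions chosen for the example. The channels-last layout and the reduction ratio of 4 mirror the TensorFlow code used later in this article.

```python
import numpy as np


def se_block(features, w1, b1, w2, b2):
    """Apply squeeze, excitation and recalibration to a batch of feature maps.

    features: array of shape (N, H, W, C), channels last.
    w1: (C, C // r), b1: (C // r,), w2: (C // r, C), b2: (C,)
    """
    # Squeeze: global average pooling over H and W gives one number per channel.
    squeezed = features.mean(axis=(1, 2))                 # (N, C)
    # Excitation: FC -> ReLU -> FC -> sigmoid, giving per-channel gates in (0, 1).
    hidden = np.maximum(squeezed @ w1 + b1, 0.0)          # (N, C // r)
    gates = 1.0 / (1.0 + np.exp(-(hidden @ w2 + b2)))     # (N, C)
    # Recalibration: broadcast the gates over H and W and rescale each channel.
    return features * gates[:, None, None, :], gates


# Illustrative sizes and random weights (assumptions, not trained values).
rng = np.random.default_rng(0)
n, h, w, c, r = 2, 4, 4, 8, 4
features = rng.normal(size=(n, h, w, c))
w1 = rng.normal(size=(c, c // r)) * 0.1
b1 = np.zeros(c // r)
w2 = rng.normal(size=(c // r, c)) * 0.1
b2 = np.zeros(c)

scaled, gates = se_block(features, w1, b1, w2, b2)
print(scaled.shape, gates.shape)  # (2, 4, 4, 8) (2, 8)
```

The output keeps the input shape, so a block like this can be dropped after any convolutional stage without changing the layers around it.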
[Figure: SENet model architecture]
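SENet is very simple to construct and easy to deploy, because it requires no new functions or layers. The Caffe model can be downloaded from Baidu Pan (https://pan.baidu.com/s/1o7HdfAE?errno=0&errmsg=Auth%20Login%20Sucess&&bduss=&ssnerror=0&traceid=)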
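Using SENet
If you already have Caffe deployed, you can download the model above and load it directly. If you have TensorFlow installed instead of Caffe, that is fine too: as noted, SENet introduces no new functions or layers, so it is easy to implement in TensorFlow.
Download the image set: after searching around, I found the dataset below. It is small and will not bring out SENet's full strength, but it is convenient to use: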
https://raw.githubusercontent.com/garythung/trashnet/master/data/dataset-resized.zip
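Build the SENet model: an open-source TensorFlow implementation is already available on GitHub at https://github.com/taki0112/SENet-Tensorflow. It uses the CIFAR-10 dataset, but that is not a problem: after running git clone, modify the prepare_data function in its cifar10.py as follows.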
    def prepare_data():
        print("======Loading data======")
        download_data()
        data_dir = 'e:/test/'
        # data_dir = './cifar-10-batches-py'  # change this to your own folder
        image_dim = image_size * image_size * img_channels
        # meta = unpickle(data_dir + '/batches.meta')  # this dataset has no meta file listing the classes, so this must change
        label_names = ['cardboard', 'glass', 'metal', 'trash', 'paper', 'plastic']
        label_count = len(label_names)
        # train_files = ['data_batch_%d' % d for d in range(1, 6)]
        train_files = [data_dir + s for s in label_names]  # changed to this
        train_data, train_labels = load_data(train_files, data_dir, label_count)
        test_data, test_labels = load_data(['test_batch'], data_dir, label_count)

        print("Train data:", np.shape(train_data), np.shape(train_labels))
        print("Test data:", np.shape(test_data), np.shape(test_labels))
        print("======Load finished======")

        print("======Shuffling data======")
        indices = np.random.permutation(len(train_data))
        train_data = train_data[indices]
        train_labels = train_labels[indices]
        print("======Prepare Finished======")

        return train_data, train_labels, test_data, test_labels
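The listing above does not show how the image files are enumerated or how a test set is held out. One possible first step is sketched below, assuming the archive extracts to one folder per class, named as in label_names (an assumption about the layout, not something the article states). The helpers list_labeled_files and split_train_test are hypothetical names, and the demo uses empty placeholder files in a temporary directory instead of real images.

```python
import os
import random
import tempfile

LABEL_NAMES = ['cardboard', 'glass', 'metal', 'trash', 'paper', 'plastic']


def list_labeled_files(root, label_names=LABEL_NAMES):
    """Return (path, class_index) pairs for every image under root/<label>/."""
    items = []
    for index, name in enumerate(label_names):
        class_dir = os.path.join(root, name)
        if not os.path.isdir(class_dir):
            continue  # skip a missing class folder instead of failing
        for filename in sorted(os.listdir(class_dir)):
            if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
                items.append((os.path.join(class_dir, filename), index))
    return items


def split_train_test(items, test_fraction=0.2, seed=0):
    """Shuffle with a fixed seed and hold out a fraction of the items for testing."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_test = int(len(shuffled) * test_fraction)
    return shuffled[n_test:], shuffled[:n_test]


# Demo on placeholder files: 6 classes with 5 empty .jpg files each.
with tempfile.TemporaryDirectory() as root:
    for name in LABEL_NAMES:
        os.makedirs(os.path.join(root, name))
        for i in range(5):
            open(os.path.join(root, name, '%s%d.jpg' % (name, i)), 'w').close()
    items = list_labeled_files(root)
    train_items, test_items = split_train_test(items, test_fraction=0.2)

print(len(items), len(train_items), len(test_items))  # 30 24 6
```

The integer labels would still need to be converted to one-hot vectors and the images decoded and resized before they match the placeholders used by the model code.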
The main modeling code is shown below. Its job is simply to implement the SENet model structure:
    # TensorFlow 1.x code: it relies on tf.contrib, tf.layers and tf.placeholder.
    import tensorflow as tf
    from tflearn.layers.conv import global_avg_pool
    from tensorflow.contrib.layers import batch_norm, flatten
    from tensorflow.contrib.framework import arg_scope
    from cifar10 import *
    import numpy as np

    weight_decay = 0.0005
    momentum = 0.9

    init_learning_rate = 0.1

    reduction_ratio = 4

    batch_size = 128
    iteration = 391
    # 128 * 391 ~ 50,000 (CIFAR-10 values; adjust for another dataset)

    test_iteration = 10

    total_epochs = 100

    def conv_layer(input, filter, kernel, stride=1, padding='SAME', layer_name="conv", activation=True):
        with tf.name_scope(layer_name):
            network = tf.layers.conv2d(inputs=input, use_bias=True, filters=filter, kernel_size=kernel, strides=stride, padding=padding)
            if activation:
                network = Relu(network)
            return network

    def Fully_connected(x, units=class_num, layer_name='fully_connected'):
        with tf.name_scope(layer_name):
            return tf.layers.dense(inputs=x, use_bias=True, units=units)

    def Relu(x):
        return tf.nn.relu(x)

    def Sigmoid(x):
        return tf.nn.sigmoid(x)

    def Global_Average_Pooling(x):
        return global_avg_pool(x, name='Global_avg_pooling')

    def Max_pooling(x, pool_size=[3, 3], stride=2, padding='VALID'):
        return tf.layers.max_pooling2d(inputs=x, pool_size=pool_size, strides=stride, padding=padding)

    def Batch_Normalization(x, training, scope):
        with arg_scope([batch_norm],
                       scope=scope,
                       updates_collections=None,
                       decay=0.9,
                       center=True,
                       scale=True,
                       zero_debias_moving_mean=True):
            return tf.cond(training,
                           lambda: batch_norm(inputs=x, is_training=training, reuse=None),
                           lambda: batch_norm(inputs=x, is_training=training, reuse=True))

    def Concatenation(layers):
        # Concatenate along the channel axis (NHWC layout).
        return tf.concat(layers, axis=3)

    def Dropout(x, rate, training):
        return tf.layers.dropout(inputs=x, rate=rate, training=training)

    def Evaluate(sess):
        test_acc = 0.0
        test_loss = 0.0
        test_pre_index = 0
        add = 1000

        for it in range(test_iteration):
            test_batch_x = test_x[test_pre_index: test_pre_index + add]
            test_batch_y = test_y[test_pre_index: test_pre_index + add]
            test_pre_index = test_pre_index + add

            test_feed_dict = {
                x: test_batch_x,
                label: test_batch_y,
                learning_rate: epoch_learning_rate,
                training_flag: False
            }

            loss_, acc_ = sess.run([cost, accuracy], feed_dict=test_feed_dict)

            test_loss += loss_
            test_acc += acc_

        test_loss /= test_iteration  # average loss
        test_acc /= test_iteration  # average accuracy

        summary = tf.Summary(value=[tf.Summary.Value(tag='test_loss', simple_value=test_loss),
                                    tf.Summary.Value(tag='test_accuracy', simple_value=test_acc)])

        return test_acc, test_loss, summary

    class SE_Inception_resnet_v2():
        def __init__(self, x, training):
            self.training = training
            self.model = self.Build_SEnet(x)

        def Stem(self, x, scope):
            with tf.name_scope(scope):
                x = conv_layer(x, filter=32, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_conv1')
                x = conv_layer(x, filter=32, kernel=[3, 3], padding='VALID', layer_name=scope+'_conv2')
                block_1 = conv_layer(x, filter=64, kernel=[3, 3], layer_name=scope+'_conv3')

                split_max_x = Max_pooling(block_1)
                split_conv_x = conv_layer(block_1, filter=96, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_split_conv1')
                x = Concatenation([split_max_x, split_conv_x])

                split_conv_x1 = conv_layer(x, filter=64, kernel=[1, 1], layer_name=scope+'_split_conv2')
                split_conv_x1 = conv_layer(split_conv_x1, filter=96, kernel=[3, 3], padding='VALID', layer_name=scope+'_split_conv3')

                split_conv_x2 = conv_layer(x, filter=64, kernel=[1, 1], layer_name=scope+'_split_conv4')
                split_conv_x2 = conv_layer(split_conv_x2, filter=64, kernel=[7, 1], layer_name=scope+'_split_conv5')
                split_conv_x2 = conv_layer(split_conv_x2, filter=64, kernel=[1, 7], layer_name=scope+'_split_conv6')
                split_conv_x2 = conv_layer(split_conv_x2, filter=96, kernel=[3, 3], padding='VALID', layer_name=scope+'_split_conv7')

                x = Concatenation([split_conv_x1, split_conv_x2])

                split_conv_x = conv_layer(x, filter=192, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_split_conv8')
                split_max_x = Max_pooling(x)

                x = Concatenation([split_conv_x, split_max_x])

                x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
                x = Relu(x)

                return x

        def Inception_resnet_A(self, x, scope):
            with tf.name_scope(scope):
                init = x

                split_conv_x1 = conv_layer(x, filter=32, kernel=[1, 1], layer_name=scope+'_split_conv1')

                split_conv_x2 = conv_layer(x, filter=32, kernel=[1, 1], layer_name=scope+'_split_conv2')
                split_conv_x2 = conv_layer(split_conv_x2, filter=32, kernel=[3, 3], layer_name=scope+'_split_conv3')

                split_conv_x3 = conv_layer(x, filter=32, kernel=[1, 1], layer_name=scope+'_split_conv4')
                split_conv_x3 = conv_layer(split_conv_x3, filter=48, kernel=[3, 3], layer_name=scope+'_split_conv5')
                split_conv_x3 = conv_layer(split_conv_x3, filter=64, kernel=[3, 3], layer_name=scope+'_split_conv6')

                x = Concatenation([split_conv_x1, split_conv_x2, split_conv_x3])
                x = conv_layer(x, filter=384, kernel=[1, 1], layer_name=scope+'_final_conv1', activation=False)

                # Scale the residual branch before adding it to the shortcut.
                x = x * 0.1
                x = init + x

                x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
                x = Relu(x)

                return x

        def Inception_resnet_B(self, x, scope):
            with tf.name_scope(scope):
                init = x

                split_conv_x1 = conv_layer(x, filter=192, kernel=[1, 1], layer_name=scope+'_split_conv1')

                split_conv_x2 = conv_layer(x, filter=128, kernel=[1, 1], layer_name=scope+'_split_conv2')
                split_conv_x2 = conv_layer(split_conv_x2, filter=160, kernel=[1, 7], layer_name=scope+'_split_conv3')
                split_conv_x2 = conv_layer(split_conv_x2, filter=192, kernel=[7, 1], layer_name=scope+'_split_conv4')

                x = Concatenation([split_conv_x1, split_conv_x2])
                x = conv_layer(x, filter=1152, kernel=[1, 1], layer_name=scope+'_final_conv1', activation=False)
                # 1154
                x = x * 0.1
                x = init + x

                x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
                x = Relu(x)

                return x

        def Inception_resnet_C(self, x, scope):
            with tf.name_scope(scope):
                init = x

                split_conv_x1 = conv_layer(x, filter=192, kernel=[1, 1], layer_name=scope+'_split_conv1')

                split_conv_x2 = conv_layer(x, filter=192, kernel=[1, 1], layer_name=scope+'_split_conv2')
                split_conv_x2 = conv_layer(split_conv_x2, filter=224, kernel=[1, 3], layer_name=scope+'_split_conv3')
                split_conv_x2 = conv_layer(split_conv_x2, filter=256, kernel=[3, 1], layer_name=scope+'_split_conv4')

                x = Concatenation([split_conv_x1, split_conv_x2])
                x = conv_layer(x, filter=2144, kernel=[1, 1], layer_name=scope+'_final_conv2', activation=False)
                # 2048
                x = x * 0.1
                x = init + x

                x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
                x = Relu(x)

                return x

        def Reduction_A(self, x, scope):
            with tf.name_scope(scope):
                k = 256
                l = 256
                m = 384
                n = 384

                split_max_x = Max_pooling(x)

                split_conv_x1 = conv_layer(x, filter=n, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_split_conv1')

                split_conv_x2 = conv_layer(x, filter=k, kernel=[1, 1], layer_name=scope+'_split_conv2')
                split_conv_x2 = conv_layer(split_conv_x2, filter=l, kernel=[3, 3], layer_name=scope+'_split_conv3')
                split_conv_x2 = conv_layer(split_conv_x2, filter=m, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_split_conv4')

                x = Concatenation([split_max_x, split_conv_x1, split_conv_x2])

                x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
                x = Relu(x)

                return x

        def Reduction_B(self, x, scope):
            with tf.name_scope(scope):
                split_max_x = Max_pooling(x)

                split_conv_x1 = conv_layer(x, filter=256, kernel=[1, 1], layer_name=scope+'_split_conv1')
                split_conv_x1 = conv_layer(split_conv_x1, filter=384, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_split_conv2')

                split_conv_x2 = conv_layer(x, filter=256, kernel=[1, 1], layer_name=scope+'_split_conv3')
                split_conv_x2 = conv_layer(split_conv_x2, filter=288, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_split_conv4')

                split_conv_x3 = conv_layer(x, filter=256, kernel=[1, 1], layer_name=scope+'_split_conv5')
                split_conv_x3 = conv_layer(split_conv_x3, filter=288, kernel=[3, 3], layer_name=scope+'_split_conv6')
                split_conv_x3 = conv_layer(split_conv_x3, filter=320, kernel=[3, 3], stride=2, padding='VALID', layer_name=scope+'_split_conv7')

                x = Concatenation([split_max_x, split_conv_x1, split_conv_x2, split_conv_x3])

                x = Batch_Normalization(x, training=self.training, scope=scope+'_batch1')
                x = Relu(x)

                return x

        def Squeeze_excitation_layer(self, input_x, out_dim, ratio, layer_name):
            with tf.name_scope(layer_name):
                # Squeeze: one value per channel.
                squeeze = Global_Average_Pooling(input_x)

                # Excitation: bottleneck FC -> ReLU -> FC -> sigmoid.
                # Integer division keeps `units` an int under Python 3.
                excitation = Fully_connected(squeeze, units=out_dim // ratio, layer_name=layer_name+'_fully_connected1')
                excitation = Relu(excitation)
                excitation = Fully_connected(excitation, units=out_dim, layer_name=layer_name+'_fully_connected2')
                excitation = Sigmoid(excitation)

                # Recalibration: broadcast the channel weights over H and W.
                excitation = tf.reshape(excitation, [-1, 1, 1, out_dim])
                scale = input_x * excitation

                return scale

        def Build_SEnet(self, input_x):
            input_x = tf.pad(input_x, [[0, 0], [32, 32], [32, 32], [0, 0]])
            # size 32 -> 96
            print(np.shape(input_x))
            # only cifar10 architecture

            x = self.Stem(input_x, scope='stem')

            for i in range(5):
                x = self.Inception_resnet_A(x, scope='Inception_A'+str(i))
                channel = int(np.shape(x)[-1])
                x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_A'+str(i))

            x = self.Reduction_A(x, scope='Reduction_A')

            channel = int(np.shape(x)[-1])
            x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_A')

            for i in range(10):
                x = self.Inception_resnet_B(x, scope='Inception_B'+str(i))
                channel = int(np.shape(x)[-1])
                x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_B'+str(i))

            x = self.Reduction_B(x, scope='Reduction_B')

            channel = int(np.shape(x)[-1])
            x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_B')

            for i in range(5):
                x = self.Inception_resnet_C(x, scope='Inception_C'+str(i))
                channel = int(np.shape(x)[-1])
                x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_C'+str(i))

            # channel = int(np.shape(x)[-1])
            # x = self.Squeeze_excitation_layer(x, out_dim=channel, ratio=reduction_ratio, layer_name='SE_C')

            x = Global_Average_Pooling(x)
            x = Dropout(x, rate=0.2, training=self.training)
            x = flatten(x)

            x = Fully_connected(x, layer_name='final_fully_connected')
            return x


    train_x, train_y, test_x, test_y = prepare_data()
    train_x, test_x = color_preprocessing(train_x, test_x)

    # image_size = 32, img_channels = 3, class_num = 10 in cifar10
    x = tf.placeholder(tf.float32, shape=[None, image_size, image_size, img_channels])
    label = tf.placeholder(tf.float32, shape=[None, class_num])

    training_flag = tf.placeholder(tf.bool)

    learning_rate = tf.placeholder(tf.float32, name='learning_rate')

    logits = SE_Inception_resnet_v2(x, training=training_flag).model
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=label, logits=logits))

    l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum, use_nesterov=True)
    train = optimizer.minimize(cost + l2_loss * weight_decay)

    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver(tf.global_variables())

    with tf.Session() as sess:
        # Resume from a checkpoint if one exists, otherwise start from scratch.
        ckpt = tf.train.get_checkpoint_state('./model')
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter('./logs', sess.graph)

        epoch_learning_rate = init_learning_rate
        for epoch in range(1, total_epochs + 1):
            # Divide the learning rate by 10 every 30 epochs.
            if epoch % 30 == 0:
                epoch_learning_rate = epoch_learning_rate / 10

            pre_index = 0
            train_acc = 0.0
            train_loss = 0.0

            for step in range(1, iteration + 1):
                # 50000 is the CIFAR-10 training set size; adjust for another dataset.
                if pre_index + batch_size < 50000:
                    batch_x = train_x[pre_index: pre_index + batch_size]
                    batch_y = train_y[pre_index: pre_index + batch_size]
                else:
                    batch_x = train_x[pre_index:]
                    batch_y = train_y[pre_index:]

                batch_x = data_augmentation(batch_x)

                train_feed_dict = {
                    x: batch_x,
                    label: batch_y,
                    learning_rate: epoch_learning_rate,
                    training_flag: True
                }

                _, batch_loss = sess.run([train, cost], feed_dict=train_feed_dict)
                batch_acc = accuracy.eval(feed_dict=train_feed_dict)

                train_loss += batch_loss
                train_acc += batch_acc
                pre_index += batch_size

            train_loss /= iteration  # average loss
            train_acc /= iteration  # average accuracy

            train_summary = tf.Summary(value=[tf.Summary.Value(tag='train_loss', simple_value=train_loss),
                                              tf.Summary.Value(tag='train_accuracy', simple_value=train_acc)])

            test_acc, test_loss, test_summary = Evaluate(sess)

            summary_writer.add_summary(summary=train_summary, global_step=epoch)
            summary_writer.add_summary(summary=test_summary, global_step=epoch)
            summary_writer.flush()

            line = "epoch: %d/%d, train_loss: %.4f, train_acc: %.4f, test_loss: %.4f, test_acc: %.4f \n" % (
                epoch, total_epochs, train_loss, train_acc, test_loss, test_acc)
            print(line)

            with open('logs.txt', 'a') as f:
                f.write(line)

            saver.save(sess=sess, save_path='./model/Inception_resnet_v2.ckpt')
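The training loop above hard-codes CIFAR-10's 50,000 training samples and 391 iterations per epoch, so these values need to change for a smaller dataset. The sketch below shows one way they could be derived from the dataset size, together with the step decay the loop applies to the learning rate. The helper names are hypothetical, and the sample count of 2,000 is purely illustrative, not the actual size of the waste dataset.

```python
import math


def iterations_per_epoch(num_samples, batch_size):
    """Number of batches needed to see every sample once (last batch may be short)."""
    return math.ceil(num_samples / batch_size)


def batch_bounds(step, num_samples, batch_size):
    """Return the (start, end) slice indices for a 0-based step."""
    start = step * batch_size
    end = min(start + batch_size, num_samples)
    return start, end


def step_decay(init_lr, epoch, drop_every=30, factor=10.0):
    """Learning rate after dividing by `factor` at every multiple of `drop_every` epochs."""
    return init_lr / (factor ** (epoch // drop_every))


num_samples = 2000   # illustrative value only
batch_size = 128

steps = iterations_per_epoch(num_samples, batch_size)
print(steps)                                               # 16
print(batch_bounds(steps - 1, num_samples, batch_size))    # (1920, 2000)
print(step_decay(0.1, 29), step_decay(0.1, 30))
```

With helpers like these, iteration and the 50000 constant in the loop could be replaced by values computed from len(train_x).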
Using SENet for waste sorting is honestly overkill, but it is a good way to get a feel for how powerful the model is.
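Original title: Still struggling with waste sorting? Use Python to experience the power of SENet, the ImageNet champion model
Source: WeChat official account rgznai100 (WeChat ID: rgznai100). Please credit the source when reposting.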