tensorflow图像叙事功能im2txt

Posted by DuHuazhen on April 1, 2017

tensorflow图像叙事功能https://github.com/tensorflow/models/tree/master/research/im2txt#training-a-model  
本文是基于ubuntu14.04+tensorflow1.6(cpu)+python3.6+anaconda
我的仓库:https://github.com/duhuazhen/Tensorflow_practice/tree/master/im2txt1

下载模型

tensorflow/models 下面有很多模型,可以选择下载整个 models

git clone https://github.com/tensorflow/models.git

但是考虑到速度的原因,在终端里下载的速度尤其的慢,我们选择下载子文件夹 参照知乎上的方法:   将https://github.com/tensorflow/models/tree/master/research/im2txt拷到http://kinolien.github.io/gitzip/中去,直接点击download即可下载。

下载好了之后将 models/research/im2txt/im2txt 文件夹复制到你的工作区。

安装必要的包

首先按照 Github 上 im2txt 的说明,安装所有必需的包

  • Bazel[官网方法](https://docs.bazel.build/versions/master/install.html)        由于我们是在anaconda环境下运行程序,用下面的命令更简单点:

    首先激活我们的环境(我这里的环境是tensorflow):

            source activate tensorflow
    

    然后执行下面的命令来安装bazel

             conda install bazel
    
  • TensorFlow 1.0或更高版本官网方法

    本文安装的是tensorflow1.6版本

         pip install --ignore-installed --upgrade \https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0-cp34-cp34m-linux_x86_64.whl
    
  • NumPy官网方法          由于anaconda自带numpy,这里就不需要安装了     * Natural Language Toolkit (NLTK)
           * 首先安装 NLTK官网方法
     conda install nltk
    
  • 然后下载 NLTK 数据官网方法:速度较慢,我下载了3个多小时。

    下载模型和词汇

    如果要自己训练模型,按照官网的说法,需要先下载几个小时的数据集,然后再训练1~2周,最后还要精调几个星期

训练要花不少时间,所以用训练好的模型,下载地址是

下载之后放在 im2txt/model 文件夹下

    im2txt/
    ......
    model/
        graph.pbtxt
        model.ckpt-2000000
        model.ckpt-2000000.meta

编写脚本

为了进行实验,找了提前训练好的模型,不过由于本文实验在tensorflow 1.0版本之上,需要填好几个坑,这是因为TF1.0和后面的LSTM在命名上出现了差异,所以需要根据错误信息自己修改   |TF1.0|TF1.2| |—-|—–| |lstm/basic_lstm_cell/weights|stm/basic_lstm_cell/kernel| |lstm/basic_lstm_cell/biases| lstm/basic_lstm_cell/bias| 解决方式是,新建 rename_ckpt.py 文件,使用输入以下方法将原有训练模型转化

import tensorflow as tf


# 由于 TensorFlow 的版本不同,所以要根据具体错误信息进行修改

def rename_ckpt():
    vars_to_rename = {
        "lstm/BasicLSTMCell/Linear/Bias": "lstm/basic_lstm_cell/bias",
        "lstm/BasicLSTMCell/Linear/Matrix": "lstm/basic_lstm_cell/kernel"
    }
    new_checkpoint_vars = {}
    reader = tf.train.NewCheckpointReader(
        "/home/hehe/anaconda3/envs/tensorflow/im2txt/model/model.ckpt-2000000"
    )
    for old_name in reader.get_variable_to_shape_map():
        if old_name in vars_to_rename:
            new_name = vars_to_rename[old_name]
        else:
            new_name = old_name
        new_checkpoint_vars[new_name] = tf.Variable(
            reader.get_tensor(old_name))

    init = tf.global_variables_initializer()
    saver = tf.train.Saver(new_checkpoint_vars)

    with tf.Session() as sess:
        sess.run(init)
        saver.save(
            sess,
            "/home/hehe/anaconda3/envs/tensorflow/im2txt/model/newmodel.ckpt-2000000"
        )
    print("checkpoint file rename successful... ")


if __name__ == '__main__':
    rename_ckpt()

在 im2txt 文件夹下新建一个 run.sh 脚本文件,输入以下命令

# 模型文件夹或文件路径
CHECKPOINT_PATH="/home/hehe/anaconda3/envs/tensorflow/im2txt/model/newmodel.ckpt-2000000"
# 词汇文件
VOCAB_FILE="/home/hehe/anaconda3/envs/tensorflow/im2txt/data/word_counts.txt"
# 图片文件,多个图片用逗号隔开
IMAGE_FILE="/home/hehe/anaconda3/envs/tensorflow/im2txt/data/images/3.jpg"

# bazel编译
bazel build -c opt //im2txt:run_inference

# 用参数调用编译后的文件
bazel-bin/im2txt/run_inference \
  --checkpoint_path=${CHECKPOINT_PATH} \
  --vocab_file=${VOCAB_FILE} \
  --input_files=${IMAGE_FILE}

其中的变量用自己的路径代替.

运行脚本

将当前工作目录设置为 im2txt,设置脚本的权限

sudo chmod 777 run.sh

然后将工作目录设置为 im2txt 的上层目录,运行脚本

./im2txt/run.sh

实验结果

谷歌经典的实验用图
3.jpg

Captions for image 3.jpg:
  0) a small dog is eating a slice of pizza . (p=0.000129)
  1) a dog is sitting at a table with a pizza . (p=0.000054)
  2) a dog is sitting at a table with a plate of food . (p=0.000047)

为了测试有效性,我们从百度上随机找了一种图片
6.jpg   结果如下


Captions for image 6.jpg:
  0) a dog laying on the grass with a frisbee in its mouth . (p=0.001007)
  1) a brown and white dog laying on top of a grass covered field . (p=0.000901)
  2) a brown and white dog laying on top of a green field . (p=0.000826)

总体来说效果还是可以的

可惜由于实验硬件太差,要不可以结合inception v4来训练,应该效果会更好。另外,还有中文标签的生成。 其余错误可看参考中的第一个链接
参考:  https://blog.csdn.net/White_Idiot/article/details/78699351

https://blog.csdn.net/johnnie_turbo/article/details/77931506#inception-v3%E6%A8%A1%E5%9E%8B%E4%B8%8B%E8%BD%BD

https://blog.csdn.net/sparkexpert/article/details/70846094

https://blog.csdn.net/gbbb1234/article/details/70543584