本小节主要介绍如何将数据的 embedding 导出,并在 TensorBoard 的 Embedding Projector 中可视化显示,代码如下:

#!/usr/bin/env python
# -*- coding:utf-8 -*- 
#Author: Xusong Chen

import os
import sys
import argparse

import numpy as np
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector
from tensorflow.examples.tutorials.mnist import input_data

# Parsed command-line flags (the argparse.Namespace returned by
# parse_known_args); assigned in the ``__main__`` block at the bottom of the
# script before tf.app.run() invokes main().
FLAGS = None

def generate_embedding():
    """Export MNIST test images as a TensorBoard Projector embedding.

    Reads the first ``FLAGS.image_num`` MNIST test images, stores them in a
    2-D ``tf.Variable`` named ``embedding``, checkpoints that variable into
    ``FLAGS.log_dir`` and writes the projector configuration that links the
    variable to a label metadata file and a sprite image.

    Files written to ``FLAGS.log_dir``:
      * ``metadata.tsv`` -- one integer label per line, one line per image.
      * a ``model.ckpt`` checkpoint (saved with global_step=0) holding the
        embedding variable.
      * an events file and the projector configuration read by TensorBoard.

    The sprite image ``mnist_10k_sprite.png`` is NOT generated here; it is
    expected to already exist in ``FLAGS.data_dir`` and is only referenced.
    """
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False)
    # The first image_num test images are the points to visualize; with
    # one_hot=False their labels are plain integers, written one per line.
    images = mnist.test.images[:FLAGS.image_num]
    labels = mnist.test.labels[:FLAGS.image_num]
    metadata_path = os.path.join(FLAGS.log_dir, 'metadata.tsv')
    np.savetxt(metadata_path, labels, fmt='%d')
    sprite_path = os.path.join(FLAGS.data_dir, 'mnist_10k_sprite.png')

    # Set up a 2-D variable that holds the embedding; the projector reads
    # its value back from the checkpoint saved below.
    embedding_var = tf.Variable(images, name='embedding')

    # A ``with`` block guarantees the session is closed (the previous
    # InteractiveSession was never closed and leaked its resources).
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Save the embedding into the same log_dir TensorBoard will read.
        saver = tf.train.Saver()
        saver.save(sess, os.path.join(FLAGS.log_dir, "model.ckpt"),
                   global_step=0)

    summary_writer = tf.summary.FileWriter(FLAGS.log_dir)
    try:
        # Format: tensorflow/contrib/tensorboard/plugins/projector/projector_config.proto
        config = projector.ProjectorConfig()
        embedding = config.embeddings.add()
        embedding.tensor_name = embedding_var.name
        # Link this tensor to its metadata file (the labels).
        embedding.metadata_path = metadata_path
        # Link this tensor to its sprite image; each MNIST thumbnail is 28x28.
        embedding.sprite.image_path = sprite_path
        embedding.sprite.single_image_dim.extend([28, 28])
        # Save the configuration file that TensorBoard reads during startup.
        projector.visualize_embeddings(summary_writer, config)
    finally:
        # Flush and release the events-file handle even if config writing fails.
        summary_writer.close()


def main(_):
    """Entry point for tf.app.run: reset the log directory, then export.

    Any existing ``FLAGS.log_dir`` is removed recursively so every run starts
    from an empty directory, the directory is recreated, and
    ``generate_embedding`` then populates it. The positional argument (the
    argv list handed over by tf.app.run) is ignored.
    """
    target_dir = FLAGS.log_dir
    # Drop stale checkpoints / configs left over from a previous run.
    if tf.gfile.Exists(target_dir):
        tf.gfile.DeleteRecursively(target_dir)
    tf.gfile.MakeDirs(target_dir)

    generate_embedding()

if __name__ == '__main__':
    # Command-line options. Arguments argparse does not recognize are
    # forwarded to tf.app.run() together with the program name.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--image_num',
        type=int,
        default=10000,
        help='Number of images',
    )
    parser.add_argument(
        '--log_dir',
        type=str,
        default='/tmp/tensorflow/mnist/embedding/logs/',
        help='directory to store summaries',
    )
    # NOTE(review): this default is a machine-specific absolute path; on other
    # machines pass --data_dir explicitly. generate_embedding also expects
    # mnist_10k_sprite.png to be present in this directory.
    parser.add_argument(
        '--data_dir',
        type=str,
        default='/home/chen/projects/PycharmProjects/deep-learning/learning_tensorflow/MNIST_data',
        help='directory for storing mnist data',
    )

    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(
        main=main,
        argv=[sys.argv[0]] + unparsed,
    )

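然后运行 `tensorboard --logdir=LOG_DIR`,在 TensorBoard 的 Embedding Projector 页面中查看结果。其中 LOG_DIR 为存储 summaries 的目录,即脚本参数 `--log_dir` 指定的路径(默认为 /tmp/tensorflow/mnist/embedding/logs/)。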

代码参考1,2.

什么是 one hot?`input_data.read_data_sets` 中有一个参数 `one_hot`。根据 Quora 上的回答,one hot 是一种编码方式:用一个长度等于类别数的向量来表示标签,只有该类别对应的位置为 1,其余位置都为 0。MNIST 共有 0~9 十个类别,如果使用 one hot 编码,那么

5 = 0000010000
2 = 0010000000

所以在这里,如果将参数 one_hot 设为 True(即把代码中的 `one_hot=False` 改为 `one_hot=True`),那么在 TensorBoard 上显示的结果如下所示:

one_hot
图 one_hot

可以看出,每个 label 显示的都是一个由 0 和 1 组成的向量,1 所在的位置代表该数字是几。如果将其设为 False,显示的就是一个整数(即该数字本身),如下图所示:

not_one_hot
图 not_one_hot

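小结:在 TensorBoard 上显示 embedding 数据的流程为:

  1. 读入数据,并保存为 Tensor 的形式(此处为一个 `tf.Variable`);
  2. 设置保存数据的路径(LOG_DIR),并用 `tf.train.Saver` 将该变量保存为 checkpoint;
  3. 通过 `projector.ProjectorConfig` 将 metadata(标签)和 sprite 图片关联到你的 embedding,再调用 `projector.visualize_embeddings` 写出配置文件。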

results matching ""

    No results matching ""