本小节主要介绍如何将数据的embedding在TensorBoard中可视化显示出来,代码如下:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
#Author: Xusong Chen
import os
import sys
import argparse
import numpy as np
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector
from tensorflow.examples.tutorials.mnist import input_data
# Parsed command-line options (an argparse.Namespace). Assigned by
# parser.parse_known_args() in the __main__ block before tf.app.run() calls main().
FLAGS = None
def generate_embedding():
    """Write an MNIST test-set embedding and its projector config to FLAGS.log_dir.

    Takes the first FLAGS.image_num MNIST test images and stores them in a
    TensorFlow variable named 'embedding'. It then:
    - checkpoints that variable into FLAGS.log_dir;
    - writes the matching digit labels to FLAGS.log_dir/metadata.tsv;
    - saves a projector config linking the tensor to the labels and to the
      sprite image 'mnist_10k_sprite.png' expected in FLAGS.data_dir.

    Afterwards run `tensorboard --logdir=<FLAGS.log_dir>` to view it.
    """
    # one_hot=False keeps labels as plain integers, so each metadata.tsv row
    # is a single digit (written with fmt='%d') instead of a 0/1 vector.
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False)
    images = mnist.test.images[:FLAGS.image_num]
    labels = mnist.test.labels[:FLAGS.image_num]

    metadata_path = os.path.join(FLAGS.log_dir, 'metadata.tsv')
    np.savetxt(metadata_path, labels, fmt='%d')
    sprite_path = os.path.join(FLAGS.data_dir, 'mnist_10k_sprite.png')

    # 2D variable holding the embedding; its checkpoint in log_dir is what
    # the projector config below points at (via tensor_name).
    embedding_var = tf.Variable(images, name='embedding')

    # A context-managed session replaces the original, never-closed
    # InteractiveSession so its resources are released once saving is done.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.save(sess, os.path.join(FLAGS.log_dir, "model.ckpt"), global_step=0)

    # Use the same log_dir where the checkpoint was stored.
    summary_writer = tf.summary.FileWriter(FLAGS.log_dir)
    try:
        config = projector.ProjectorConfig()
        embedding = config.embeddings.add()
        embedding.tensor_name = embedding_var.name
        # Link this tensor to its metadata file (digit labels).
        embedding.metadata_path = metadata_path
        # Link this tensor to its sprite image; each tile is 28x28 pixels.
        embedding.sprite.image_path = sprite_path
        embedding.sprite.single_image_dim.extend([28, 28])
        # Saves a configuration file that TensorBoard will read during startup.
        projector.visualize_embeddings(summary_writer, config)
    finally:
        # The writer was previously never closed; close it even on failure.
        summary_writer.close()
def main(_):
    """Entry point handed to tf.app.run: reset the log directory, then build the embedding.

    The single positional argument (leftover argv forwarded by tf.app.run)
    is ignored. Any existing FLAGS.log_dir is removed recursively and
    recreated. Presumably this keeps checkpoints and projector files from
    earlier runs from mixing with the new ones.
    """
    log_dir = FLAGS.log_dir
    already_there = tf.gfile.Exists(log_dir)
    if already_there:
        tf.gfile.DeleteRecursively(log_dir)
    tf.gfile.MakeDirs(log_dir)
    generate_embedding()
if __name__ == '__main__':
    # Command-line interface; parsed options become the module-level FLAGS.
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_num', type=int, default=10000,
                        help='Number of images')
    parser.add_argument('--log_dir', type=str, default='/tmp/tensorflow/mnist/embedding/logs/',
                        help='directory to store summaries')
    # Portable default alongside --log_dir, replacing a user-specific home
    # directory that does not exist on other machines. generate_embedding()
    # also expects the sprite image mnist_10k_sprite.png in this directory.
    parser.add_argument('--data_dir', type=str,
                        default='/tmp/tensorflow/mnist/input_data',
                        help='directory for storing mnist data')
    FLAGS, unparsed = parser.parse_known_args()
    # images[:n] with n <= 0 would take nothing or drop images from the end,
    # rather than selecting the first n test images.
    if FLAGS.image_num <= 0:
        parser.error('--image_num must be a positive integer')
    # Forward unrecognized arguments to tf.app.run (program name first).
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
然后运行 tensorboard --logdir=LOG_DIR 来查看,其中LOG_DIR为存储summaries的目录(即代码中的--log_dir参数)。
什么是one hot?在input_data.read_data_sets中有这样一个参数one_hot。根据Quora上的回答,one hot是一种编码方式:MNIST中共有10个类别(数字0~9),如果使用one hot编码,每个标签就表示为一个长度为10的向量,只有该数字对应的位置为1,其余位置均为0,例如:
5 = 0000010000
2 = 0010000000
所以在这里,如果将参数one_hot设为True,那么在TensorBoard上显示的结果如下所示:
可以看出,每个label显示的都是一个由0和1组成的向量,1所在的位置代表该数字是几。而如果将其设为False,显示的就是一个整数(即该数字本身),如下图所示:
小结:在TensorBoard上显示embedding数据的流程为:
- 读入数据,并保存为Tensor(Variable)的形式;
- 设置保存数据的路径,并将该Variable保存为checkpoint;
- 关联数据和你的embedding(metadata标签文件、sprite图片),并通过projector生成TensorBoard启动时读取的配置文件。