首页
登录 | 注册

tensorflow 的tfrecord格式转为图片

具体来说,我认为数据增强肯定是要把tfrecord转为图片,才能增强嘛。tfrecord格式读取更快,比图片读取。

当然有人说我就是喜欢拿原来图片直接数据增强,额,简单直接,能达到目的也行。

再一个上篇我实现了image-->tfrecord格式,所以为了验证对否,还是要转回来看看。

总的来说只有3步:

1.读取tfrecords,只是读取器变成了tf.TFRecordReader来读取tfrecord文件。

2.通过一个解析器tf.parse_single_example ,解析这个特殊的tfrecord格式文件。

3.然后用解码器 tf.decode_raw 解码。

========

效果如下:

tensorflow 的tfrecord格式转为图片

直接上代码:

# -*- coding: utf-8 -*-

import tensorflow as tf
from PIL import Image  

#写入将要保存图片路径,需要自己手动新建文件夹
swd = './tfrecord2pic'+'/'
#TFRecord文件路径,只能打开某一个具体的tfrecord,有多个那就改一下咯。
data_path = './traindata.tfrecords-003'
# 获取文件名列表
data_files = tf.gfile.Glob(data_path)
# 文件名列表生成器
filename_queue = tf.train.string_input_producer(data_files,shuffle=True)

reader = tf.TFRecordReader()

#上一篇说了,tfrecord格式数据度保存在值里面,即serialized_example,所以键不管

_, serialized_example = reader.read(filename_queue)   #返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                       'img_width': tf.FixedLenFeature([], tf.int64),
                                       'img_height': tf.FixedLenFeature([], tf.int64),
                                   })  #取出包含image和label的feature对象
#tf.decode_raw可以将字符串解析成图像对应的像素数组
image = tf.decode_raw(features['img_raw'], tf.uint8)
height = tf.cast(features['img_height'],tf.int32)
width = tf.cast(features['img_width'],tf.int32)
label = tf.cast(features['label'], tf.int32)
channel = 3
image = tf.reshape(image, [height,width,channel])

with tf.Session() as sess: #开始一个会话
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    #启动多线程
    coord=tf.train.Coordinator()

    threads= tf.train.start_queue_runners(coord=coord)

#循环6次,所以转化了6张图片

    for i in range(6):
        single,l = sess.run([image,label])#在会话中取出image和label

        img=Image.fromarray(single, 'RGB')#这里Image是之前提到的

#存下图片,格式是  第几张图片_label_所属类别标签号

        img.save(swd+str(i)+'_''Label_'+str(l)+'.jpg')
    coord.request_stop()

    coord.join(threads)

====提示

1.红色字体是需要修改的部分,看注释改下吧。

2.每次读取都是从tfrecord第一张图片开始读取的,for i in range(xxx,xxxx):这里设置图片的编号,

同时还设置转化的图片数目。你如果想从第10张以后开始读取保存图片,简单:

你写个if语句,判断循环了多少次嘛。到了第10次才开始保存即可。

或者你想跳过中间某些图片不处理,还是写个if语句,count在那个范围之内你再读取嘛。

3.比如你tfrecord有300张图片,你设置for i in range(xxx,xxxx)读取500张,

   其实后面200张图片又从tfrecord头开始读取重复图片了,不会保错。




2020 jeepxie.net webmaster#jeepxie.net
10 q. 0.008 s.
京ICP备10005923号