2018年4月22日星期日

加载训练好的模型

加载训练好的模型
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 22 13:52:11 2018

@author: sisyphus
"""


from skimage import io,transform
import tensorflow as tf
import numpy as np



# Absolute paths of the five sample images that are run through the network.
path1 = '/Users/sisyphus/darkflow/sample_img/sample_computer.jpg'
path2 = '/Users/sisyphus/darkflow/sample_img/sample_dog.jpg'
path3 = '/Users/sisyphus/darkflow/sample_img/sample_eagle.jpg'
path4 = '/Users/sisyphus/darkflow/sample_img/sample_horses.jpg'
path5 = '/Users/sisyphus/darkflow/sample_img/sample_person.jpg'

# Index -> label lookup. Despite the name, the labels match the subjects of
# the sample images above (computer, dog, ...), not flowers.
flower_dict = dict(enumerate(['computer', 'dog', 'eagle', 'horses', 'person']))

# Target size used when resizing input images.
w = h = 416
# Presumably the channel count; not referenced by the visible code.
c = 3

def read_one_image(path):
    """Read the image at ``path`` and resize it to the network input size.

    Args:
        path: Location of the image file, passed straight to
            ``skimage.io.imread``.

    Returns:
        The resized image as a NumPy array with ``h`` rows and ``w`` columns
        (module-level ``h`` and ``w``).
    """
    img = io.imread(path)
    # ``transform.resize`` takes its target shape as (rows, cols), i.e.
    # (height, width). The original passed (w, h), which only worked because
    # w == h.
    img = transform.resize(img, (h, w))
    return np.asarray(img)

with tf.Session() as sess:
    # Load and resize the five sample images into one batch.
    data = [read_one_image(p) for p in (path1, path2, path3, path4, path5)]

    # Rebuild the graph from the .meta file, then restore the weights from the
    # most recent checkpoint in the same directory.
    saver = tf.train.import_meta_graph('/Users/sisyphus/darkflow/ckpt/yolo-voc-9.meta')
    saver.restore(sess, tf.train.latest_checkpoint('/Users/sisyphus/darkflow/ckpt/'))

    graph = tf.get_default_graph()
    # Named ``input_tensor`` rather than ``input`` so the builtin is not shadowed.
    input_tensor = graph.get_tensor_by_name("input:0")
    logits = graph.get_tensor_by_name("output:0")

    result = sess.run(logits, feed_dict={input_tensor: data})

    # NOTE: the original called ``framework.findboxes(result)`` here, but
    # ``framework`` is never imported or defined in this file, so the script
    # raised NameError before printing anything. The returned boxes were also
    # never used, so the call has been removed.

    # Print the raw network output.
    print(result)
    # Print the index of the maximum along axis 1 for each batch element.
    # ``result`` is already a NumPy array, so np.argmax avoids adding a new op
    # to the restored graph.
    print(np.argmax(result, 1))

    # Map each index to its label via flower_dict. Left disabled, as in the
    # original; presumably only meaningful when "output:0" holds per-class
    # scores of shape (N, 5) -- TODO confirm.
    # predictions = np.argmax(result, 1)
    # for i, idx in enumerate(predictions):
    #     print("第",i+1,"张图片预测:"+flower_dict[idx])
       
       
       
       
       
       
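# Sample output (运行结果). NOTE: the 5x5 score matrix and flower names below appear to come from a five-class flower classifier, not from the script above -- verify before relying on it: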
#[[  5.76620245   3.18228579  -3.89464641  -2.81310582   1.40294015]
# [ -1.01490593   3.55570269  -2.76053429   2.93104005  -3.47138596]
# [ -8.05292606  -7.26499033  11.70479774   0.59627819   2.15948296]
# [ -5.12940931   2.18423128  -3.33257103   9.0591135    5.03963232]
# [ -4.25288343  -0.95963973  -2.33347392   1.54485476   5.76069307]]
#[0 1 2 3 4]
#第 1 朵花预测:dasiy
#第 2 朵花预测:dandelion
#第 3 朵花预测:roses
#第 4 朵花预测:sunflowers
#第 5 朵花预测:tulips

没有评论:

发表评论

Failed to find TIFF library

ImportError: Failed to find TIFF library. Make sure that libtiff is installed and its location is listed in PATH|LD_LIBRARY_PATH|.. 解决方法: ...