加载训练好的模型
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 22 13:52:11 2018
@author: sisyphus
"""
from skimage import io,transform
import tensorflow as tf
import numpy as np
# Sample images (from the darkflow checkout) that are batched and fed to the
# network, in this order.
path1 = "/Users/sisyphus/darkflow/sample_img/sample_computer.jpg"
path2 = "/Users/sisyphus/darkflow/sample_img/sample_dog.jpg"
path3 = "/Users/sisyphus/darkflow/sample_img/sample_eagle.jpg"
path4 = "/Users/sisyphus/darkflow/sample_img/sample_horses.jpg"
path5 = "/Users/sisyphus/darkflow/sample_img/sample_person.jpg"

# Index -> label lookup for printing predictions. Despite the name, the labels
# are the subjects of the sample images above, not flower classes.
flower_dict = {
    0: 'computer',
    1: 'dog',
    2: 'eagle',
    3: 'horses',
    4: 'person',
}

# Network input geometry: width, height, channels.
w, h, c = 416, 416, 3
def read_one_image(path, height=None, width=None):
    """Read an image from disk and resize it to the network input size.

    Args:
        path: Filesystem path of the image, read with ``skimage.io.imread``.
        height: Output height in pixels; defaults to the module-level ``h``.
        width: Output width in pixels; defaults to the module-level ``w``.

    Returns:
        A NumPy array of shape (height, width, 3) produced by
        ``skimage.transform.resize``.
    """
    # Resolve defaults at call time so the module constants are read lazily,
    # exactly as the original one-argument version did.
    if height is None:
        height = h
    if width is None:
        width = w

    img = io.imread(path)
    # Images are stacked into one batch by the caller, so every image must have
    # the same channel count. NOTE(review): assumes the network expects 3
    # channels (c = 3 at module level) — confirm against the checkpoint.
    if img.ndim == 2:
        # Grayscale: replicate the single plane into 3 channels.
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] == 4:
        # RGBA (e.g. PNG with alpha): drop the alpha channel.
        img = img[..., :3]

    # skimage's output_shape is (rows, cols), i.e. (height, width); the original
    # passed (w, h), which only worked because w == h.
    img = transform.resize(img, (height, width))
    return np.asarray(img)
# Restore the trained checkpoint and run it on the five sample images.
with tf.Session() as sess:
    # Build the input batch in a fixed order; read_one_image resizes each image.
    data = [read_one_image(p) for p in (path1, path2, path3, path4, path5)]

    # Rebuild the graph from the .meta file, then load the newest weights
    # found in the checkpoint directory.
    saver = tf.train.import_meta_graph('/Users/sisyphus/darkflow/ckpt/yolo-voc-9.meta')
    saver.restore(sess, tf.train.latest_checkpoint('/Users/sisyphus/darkflow/ckpt/'))

    graph = tf.get_default_graph()
    # Named input_tensor to avoid shadowing the builtin input().
    input_tensor = graph.get_tensor_by_name("input:0")
    logits = graph.get_tensor_by_name("output:0")
    result = sess.run(logits, feed_dict={input_tensor: data})

    # NOTE(review): the original called `framework.findboxes(result)` here, but
    # `framework` is never defined in this script, so it raised NameError
    # before anything was printed (and `boxes` was unused). Box decoding needs
    # a darkflow framework object to be constructed first.

    # Print the raw prediction matrix.
    print(result)
    # Print the index of the maximum along axis 1 of each row. Computed with
    # NumPy instead of tf.argmax(...).eval() so no new op is added to the
    # restored graph. NOTE(review): this treats output:0 as
    # (batch, num_classes) logits, as in the classifier this script was adapted
    # from — confirm that matches the actual shape of output:0.
    print(np.argmax(result, axis=1))
#根据索引通过字典对应花的分类
# outputt = []
# outputt = tf.argmax(classification_result,1).eval()
# for i in range(len(outputt)):
# print("第",i+1,"张图片预测:"+flower_dict[outputt[i]])
#运行结果(以下示例输出来自本脚本改写前的花卉分类模型,类别名与上方 flower_dict 不一致):
#[[ 5.76620245 3.18228579 -3.89464641 -2.81310582 1.40294015]
# [ -1.01490593 3.55570269 -2.76053429 2.93104005 -3.47138596]
# [ -8.05292606 -7.26499033 11.70479774 0.59627819 2.15948296]
# [ -5.12940931 2.18423128 -3.33257103 9.0591135 5.03963232]
# [ -4.25288343 -0.95963973 -2.33347392 1.54485476 5.76069307]]
#[0 1 2 3 4]
#第 1 朵花预测:dasiy
#第 2 朵花预测:dandelion
#第 3 朵花预测:roses
#第 4 朵花预测:sunflowers
#第 5 朵花预测:tulips
没有评论:
发表评论