2018年5月3日星期四

利用 darkflow 训练好的模型进行推理（inference）

目前先实现单张图片的推理。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May  2 17:06:39 2018

@author: sisyphus
"""

from skimage import io,transform
import tensorflow as tf
import numpy as np
import os
import time
import pickle
from multiprocessing.pool import ThreadPool
import cv2
from darkflow.utils.box import BoundBox
from darkflow.cython_utils.cy_yolo2_findboxes import box_constructor
import json


# Darkflow model description for the trained "yolo-voc-6c" network
# (6 classes); presumably dumped from darkflow's parsed cfg/yolo-voc-6c.cfg.
meta = {
    # [net] section of the .cfg (training hyper-parameters)
    'net': {
        'type': '[net]',
        'batch': 64,
        'subdivisions': 8,
        'height': 416,
        'width': 416,
        'channels': 3,
        'momentum': 0.9,
        'decay': 0.0005,
        'angle': 0,
        'saturation': 1.5,
        'exposure': 1.5,
        'hue': 0.1,
        'learning_rate': 0.0001,
        'max_batches': 45000,
        'policy': 'steps',
        'steps': '100,25000,35000',
        'scales': '10,.1,.1',
    },
    # [region] layer settings
    'type': '[region]',
    # `num` (=5) anchor (w, h) pairs, flattened
    'anchors': [1.08, 1.19, 3.42, 4.41, 6.63, 11.38, 9.42, 5.11, 16.62, 10.52],
    'bias_match': 1,
    'classes': 6,
    'coords': 4,
    'num': 5,
    'softmax': 1,
    'jitter': 0.2,
    'rescore': 1,
    'object_scale': 5,
    'noobject_scale': 1,
    'class_scale': 1,
    'coord_scale': 1,
    'absolute': 1,
    # score threshold read by postprocess1 when filtering detections
    'thresh': 0.1,
    'random': 0,
    'model': 'cfg/yolo-voc-6c.cfg',
    'inp_size': [416, 416, 3],
    # 13x13 grid; 55 = num * (classes + coords + 1) = 5 * (6 + 4 + 1)
    'out_size': [13, 13, 55],
    'name': 'yolo-voc-6c',
    'labels': ['crazing', 'inclusion', 'patches', 'pitted_surface',
               'rolled-in_scale', 'scratches'],
    # one drawing colour per label, indexed like `labels`
    'colors': [(254.0, 254.0, 254), (222.25, 190.5, 127),
               (190.5, 127.0, 254), (158.75, 63.5, 127),
               (127.0, 254.0, 254), (95.25, 190.5, 127)],
}

def findboxes1(meta, net_out):  # yolov2/predict
    """Decode raw YOLOv2 network output into candidate bounding boxes.

    Thin wrapper around darkflow's compiled ``box_constructor``.

    Args:
        meta: darkflow model description dict, passed straight through to
            ``box_constructor``.
        net_out: network output for a single image with the batch axis
            already removed (the script squeezes axis 0). Presumably shaped
            like ``meta['out_size']``; confirm against box_constructor.

    Returns:
        Whatever ``box_constructor`` returns. Callers iterate it and read
        ``.x``, ``.y``, ``.w``, ``.h`` and ``.probs`` on each element.
    """
    return box_constructor(meta, net_out)

def process_box1(b, h, w, threshold, labels=None):
    """Convert one decoded box into clipped pixel coordinates and a label.

    Args:
        b: decoded box exposing ``probs`` (per-class scores) and ``x``, ``y``,
            ``w``, ``h`` (centre and size relative to the image; assumed to be
            normalised to [0, 1], confirm against box_constructor).
        h: image height in pixels.
        w: image width in pixels.
        threshold: the best class score must be strictly greater than this.
        labels: class names indexed by class id. Defaults to the module-level
            ``meta['labels']`` so existing callers keep working.

    Returns:
        ``(left, right, top, bot, mess, max_indx, max_prob)``, where every
        coordinate is clipped to ``[0, w-1]`` (x) or ``[0, h-1]`` (y).
        Returns ``None`` if the best class score does not exceed ``threshold``.
    """
    max_indx = np.argmax(b.probs)
    max_prob = b.probs[max_indx]
    if max_prob <= threshold:
        return None
    if labels is None:
        labels = meta['labels']
    label = labels[max_indx]

    def _clip(value, upper):
        # Truncate to int pixel, then clamp on BOTH sides: a box lying past
        # the right/bottom edge must not yield left/top > upper, nor one past
        # the left/top edge a negative right/bot.
        return max(0, min(int(value), upper))

    left = _clip((b.x - b.w / 2.) * w, w - 1)
    right = _clip((b.x + b.w / 2.) * w, w - 1)
    top = _clip((b.y - b.h / 2.) * h, h - 1)
    bot = _clip((b.y + b.h / 2.) * h, h - 1)
    mess = '{}'.format(label)
    return (left, right, top, bot, mess, max_indx, max_prob)

def postprocess1(net_out, im, meta, path1, outpath, save = True):
    """Decode the network output, draw the detections and save the results.

    Args:
        net_out: single-image network output (batch axis removed), forwarded
            to ``findboxes1``.
        im: either a path to an image (read with ``cv2.imread``) or an
            already-loaded image array. An array is drawn on in place; since
            drawing and saving go through OpenCV, it should be BGR.
        meta: darkflow model description; ``thresh`` and ``colors`` are read.
        path1: original image path; only its basename is used for outputs.
        outpath: directory under which ``output1/`` is created if missing.
        save: when False, return the annotated image without writing files.

    Returns:
        The annotated image when ``save`` is False, otherwise ``None``.
        When saving, writes ``<outpath>/output1/<basename of path1>``. If
        anything was detected it also writes a sibling ``.json`` file listing
        label, confidence, corners and area of each box; otherwise it prints
        ``Normal``.

    Raises:
        ValueError: ``im`` is a path that OpenCV cannot read.
        IOError: the annotated image could not be written.
    """
    boxes = findboxes1(meta, net_out)
    threshold = meta['thresh']
    colors = meta['colors']

    if isinstance(im, np.ndarray):
        imgcv = im
    else:
        imgcv = cv2.imread(im)
        # cv2.imread signals failure by returning None, not by raising.
        if imgcv is None:
            raise ValueError('could not read image: {}'.format(im))
    h, w, _ = imgcv.shape
    # Line thickness depends only on the image size, so compute it once.
    thick = int((h + w) // 300)

    resultsForJSON = []
    for b in boxes:
        boxResults = process_box1(b, h, w, threshold)
        if boxResults is None:
            continue
        left, right, top, bot, mess, max_indx, confidence = boxResults
        area = (bot - top) * (right - left)
        resultsForJSON.append({
            "label": mess,
            "confidence": float('%.2f' % confidence),
            "topleft": {"x": left, "y": top},
            "bottomright": {"x": right, "y": bot},
            "area": area,
        })
        cv2.rectangle(imgcv, (left, top), (right, bot),
                      colors[max_indx], thick)
        cv2.putText(imgcv, mess, (left, top - 12),
                    0, 1e-3 * h, colors[max_indx], thick // 3)

    if not save:
        return imgcv

    outfolder = os.path.join(outpath, 'output1')
    # cv2.imwrite returns False (instead of raising) when the target folder
    # does not exist, so make sure it does.
    os.makedirs(outfolder, exist_ok=True)
    img_name = os.path.join(outfolder, os.path.basename(path1))
    if not cv2.imwrite(img_name, imgcv):
        raise IOError('could not write image: {}'.format(img_name))

    if not resultsForJSON:
        print('Normal\n')
        return None

    textFile = os.path.splitext(img_name)[0] + ".json"
    print(textFile)
    print('\n')
    with open(textFile, 'w') as f:
        json.dump(resultsForJSON, f)
    return None


def read_one_image(path, size=(416, 416)):
    """Load an image from disk and preprocess it as network input.

    Args:
        path: image file path, read with skimage ``io.imread`` (RGB order).
        size: ``(width, height)`` passed to ``cv2.resize``; the default
            matches this model's 416x416 input.

    Returns:
        Float array of shape ``(height, width, 3)`` scaled to [0, 1], with the
        channels reversed from RGB to BGR. The reversal presumably matches the
        BGR order of darkflow's cv2-based training pipeline; confirm.
    """
    img = io.imread(path)
    # Normalise to exactly 3 colour channels so the reversal below always
    # yields BGR.
    if img.ndim == 2:                       # grayscale
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[-1] == 2:                # grayscale + alpha
        img = np.stack([img[..., 0]] * 3, axis=-1)
    elif img.shape[-1] == 4:                # RGBA -> drop alpha
        img = img[..., :3]
    imsz = cv2.resize(img, tuple(size))
    return imsz[:, :, ::-1] / 255.
   
   
   
path1 = "/Users/sisyphus/darkflow/sample_img/000600.jpg"  # input image
outpath = "/Users/sisyphus/darkflow/sample_img/"  # where detection results go
ckpt_dir = '/Users/sisyphus/darkflow/ckpt/'  # darkflow checkpoint directory

# Read with OpenCV (BGR) rather than skimage (RGB): postprocess1 draws colours
# on this array and saves it with cv2.imwrite, both of which assume BGR, so an
# RGB array would come out with red and blue swapped.
im = cv2.imread(path1)
if im is None:
    raise SystemExit('could not read image: {}'.format(path1))

with tf.Session() as sess:
    # A batch containing the single preprocessed image.
    data = [read_one_image(path1)]

    # Take the graph definition and the weights from the same (latest)
    # checkpoint so the two cannot drift apart.
    ckpt = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt is None:
        raise SystemExit('no checkpoint found in {}'.format(ckpt_dir))
    saver = tf.train.import_meta_graph(ckpt + '.meta')
    saver.restore(sess, ckpt)

    graph = tf.get_default_graph()
    # Named input_tensor to avoid shadowing the builtin `input` at module scope.
    input_tensor = graph.get_tensor_by_name("input:0")
    print(input_tensor)
    logits = graph.get_tensor_by_name("output:0")
    print(logits)
    result = sess.run(logits, {input_tensor: data})
    # Drop the batch axis: postprocess1 works on one image.
    netout = np.squeeze(result, axis=(0,))
    postprocess1(netout, im, meta, path1, outpath)
 

   
   

没有评论:

发表评论

Failed to find TIFF library

ImportError: Failed to find TIFF library. Make sure that libtiff is installed and its location is listed in PATH|LD_LIBRARY_PATH|.. 解决方法: ...