PyAV で「Viewing the CaffeNet model’s predictation and its target video at the same time. 」な動画を作る、の巻

ここ一ヶ月くらいのネタは元はと言えばこれが「目先の」ゴールだったりした。

2017-07-28追記: ここに書いてるコードは少し問題ありだけれどもそのままにしている。どう問題だったのかについてはこちらに書いておいた。

wavefile の扱いだったり、FFT だったり、そもそもがシンプルな画像処理(Pillow)だったり、そして ffmpeg だったり。これらは全て「機械学習にまつわるデータ処理のエクササイズ」、であった。

深層学習に限らず機械学習の分野は、個人的に「お世話になりそう」というところまでは何度も行ったのに、検証時間不足やらで結局提案や実現に至らず、要するにこれまで「やらずじまい」だったのね。ベイズ回帰程度な、実際に活用まで行ったのは。

てわけでもう何年も「やりたい」と思いながら深くやる機会も時間もなく、だけど時間も出来たことだしようやく本格的に遊び始めてみよう、てのがワタシの今の状態。丁度今深層学習がブームでもあるしね。てわけで、本まで買って勉強し始めたとこ。

インフラについては現状「評価云々」出来るレベルにはないので、ひとまず chainer で遊び始めている。

で、理解のための「入り口」としてはさ、「学習させるところから始める」ことではないと思ったのね、ワタシは。「学習済みの結果を使って予測させてみる」ことを実際に試みてみれば、「自分で全部やる場合のゴールのイメージ」をつかみやすかろうと。てわけで最初に選んだお題がこれ。この evaluate_caffe_net.py をチマチマ静止画相手に動かして按配をみていたんだけれども…。

実際「静止画相手」てのが結構大変で、「色んな静止画をお取り寄せる」ことがもう、あんまり理想的な勉強環境ではないのね。いやいや動画相手の方が手っ取り早いっしょ、と。(実際「学習させるデータ」としても、うまくすれば理想かもしらんし。)

そんな理由で白羽の矢を立てたのが PyAV。直接フレームを取り出せて、書き出しも出来る。だからフレーム単位に予測させて、その結果を動画に埋め込めるであろう、と。

てわけでやってみたのである。

evaluate_caffe_net.pyを実際に動かしてみればわかるんだけれど、モデルの読み込みに非常に時間がかかる。3分くらいかしら? caffenet model などは皆でよってたかって育てているモデルだから、凄まじくデカい。bvlc_reference_caffenet.caffemodel は 233MB。3分かかるのはこの大きさのせいだけではなかろうが、とにかくこのバッチスクリプトの構造のままでは、ワタシがやりたいことをやるのはむつかしい(特に試行錯誤を伴うので)。

なので、このバッチを単純化して CaffeNet model に限定したものを「サービス」化しておく:

predict_server.py
  1 import sys
  2 import logging
  3 import struct
  4 import SocketServer  # python 2.7
  5 from io import BytesIO, StringIO
  6 from textwrap import TextWrapper
  7 import json
  8 
  9 import numpy as np
 10 from PIL import Image, ImageOps
 11 
 12 import chainer
 13 import chainer.functions as F
 14 from chainer.links import caffe
 15 
 16 if __name__ == "__main__":
 17     logging.basicConfig(stream=sys.stderr, level=logging.INFO)
 18 
 19     #
 20     chainer.config.train = False  # All the codes will run in test mode
 21     logging.info("Loading synset_words.")
 22     categories = np.loadtxt('synset_words.txt', str, delimiter='\t')
 23     logging.info("Loaded synset_words.")
 24     logging.info("Loading mean image file.")
 25     mean_image = np.load("ilsvrc_2012_mean.npy")
 26     logging.info("Loaded mean image file.")
 27     logging.info("Loading caffenet model.")
 28     func = caffe.CaffeFunction("bvlc_reference_caffenet.caffemodel")
 29     logging.info("Loaded caffenet model.")
 30     in_size = 227
 31     def predict(x):
 32         y, = func(inputs={'data': x}, outputs=['fc8'])
 33         return F.softmax(y)
 34     start = (256 - in_size) // 2
 35     stop = start + in_size
 36     mean_image = mean_image[:, start:stop, start:stop].copy()
 37     x_batch = np.ndarray((1, 3, in_size, in_size), dtype=np.float32)
 38     #
 39     txtwrap = TextWrapper(width=50, subsequent_indent=" " * 9).fill
 40     #
 41 
 42     def _resize_to_square(img, fill=0):
 43         if img.width > img.height:
 44             border = (0, (img.width - img.height) // 2)
 45         elif img.width < img.height:
 46             border = ((img.height - img.width) // 2, 0)
 47         else:
 48             return img.copy()
 49         return ImageOps.expand(img, border, fill)
 50 
 51     def _image_pred(img):
 52         forpred = np.asarray(img.resize((256, 256))).transpose(2, 0, 1)[::-1]
 53         forpred = forpred[:, start:stop, start:stop].astype(np.float32)
 54         forpred -= mean_image
 55         x_batch[0] = forpred
 56         out = StringIO()
 57         with chainer.no_backprop_mode():
 58             score = predict(np.asarray(x_batch))
 59             prediction = zip(score.data[0].tolist(), categories)
 60             prediction.sort(cmp=lambda r, sn: cmp(r[0], sn[0]), reverse=True)
 61             for rank, (score, name) in enumerate(prediction[:7], start=1):
 62                 cid, _, cname = name.partition(" ")
 63                 out.write(txtwrap(u"%5.1f%% | %s" % (score * 100, cname)))
 64                 out.write(u"\n")
 65         return out.getvalue().encode()
 66 
 67     class ChainerPredictorHandler(SocketServer.BaseRequestHandler):
 68         def _get_img(self):
 69             imgbytessize = struct.unpack("!I", self.request.recv(4))[0]
 70             logging.info("imgbytessize={}".format(imgbytessize))
 71             #
 72             bimg = BytesIO()
 73             chunk_size = 1024
 74             read = 0
 75             while read < imgbytessize:
 76                 b = self.request.recv(chunk_size)
 77                 read += len(b)
 78                 bimg.write(b)
 79             #
 80             bimg.seek(0)
 81             return Image.open(bimg)
 82 
 83         def handle(self):
 84             # protocol:
 85             #   client send size of image
 86             #   client send image bytes
 87             #   server send text of predictation
 88             logging.debug("begin handle request.")
 89             img = self._get_img()
 90             # ----------------------------------------
 91             txt = json.dumps((
 92                 _image_pred(img),
 93                 _image_pred(_resize_to_square(img))))
 94             # ----------------------------------------
 95             self.request.sendall(txt)
 96             logging.info("send result {} bytes".format(len(txt)))
 97             logging.debug("end handle request.")
 98 
 99     HOST, PORT = "localhost", 8988
100     server = SocketServer.TCPServer((HOST, PORT), ChainerPredictorHandler)
101     logging.info("start service.")
102     server.serve_forever()

このサービスはソケットから「イメージのサイズ」と「イメージのバイト列」を受け取って、その予測結果を(2種類)返す。

evaluate_caffe_net.py を書き換えるだけではこうはならないが、「予測」については初心者向けの本を参考に書き換えた。

ごちゃごちゃしてるうちの一つ、「_resize_to_square」はワタシの素朴な疑問を試してみたくてやってる。つまり、モデルに合わせて (256, 256) にリサイズする必要があるんだけれど、「アスペクト比を維持しなくていいの?」てこと。だからアスペクト比を維持するのとしないのとでの違いを見てみたかった。「_resize_to_square」が具体的にどんな画像を作り出すのかはまぁ読めばわかるとも思うけれど、個人的に何度も必要になりそうな予感がしたので、ここに例を書いておいた

ともあれ本題のクライアントを書く前に実験用クライアント:

hoge.py
 1 import socket
 2 import sys
 3 import struct
 4 from io import BytesIO
 5 from PIL import Image
 6 
 7 HOST, PORT = "localhost", 8988
 8 
 9 # Create a socket (SOCK_STREAM means a TCP socket)
10 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
11 
12 try:
13     # Connect to server and send data
14     sock.connect((HOST, PORT))
15     img = Image.open(sys.argv[0])
16     out = BytesIO()
17     img.save(out, "PNG")
18     imgbytes = out.getvalue()
19     sock.sendall(struct.pack('!I', len(imgbytes)))
20     sock.sendall(imgbytes)
21 
22     # Receive data from the server and shut down
23     import json
24     for txt in json.loads(sock.makefile('r').readlines()[0]):
25         print(txt)
26 finally:
27     sock.close()

例えばこんな出力:

 1  68.6% | rapeseed
 2  14.9% | maze, labyrinth
 3   1.8% | stone wall
 4   1.6% | picket fence, paling
 5   1.0% | worm fence, snake fence, snake-rail
 6          fence, Virginia fence
 7   0.9% | lakeside, lakeshore
 8   0.9% | fountain
 9 
10  64.5% | rapeseed
11   4.7% | hay
12   3.0% | lakeside, lakeshore
13   2.6% | stone wall
14   2.2% | fountain
15   1.9% | dishrag, dishcloth
16   1.8% | picket fence, paling

てわけで、これを元に、本題の「動画を入力にし、予測結果を動画にそのまま書き込む」なクライアント:

  1 import socket
  2 import sys
  3 import struct
  4 import argparse
  5 import signal
  6 import logging
  7 import json
  8 from io import BytesIO
  9 from multiprocessing import Process, Queue
 10 from Queue import Empty
 11 
 12 from PIL import Image, ImageFont, ImageDraw
 13 import av
 14 
 15 def _run(args, q):
 16     def _IntHandler(signum, frame):
 17         q.put("done")
 18 
 19     signal.signal(signal.SIGINT, _IntHandler)
 20 
 21     logging.basicConfig(stream=sys.stderr, level=logging.INFO)
 22     HOST, PORT = "localhost", 8988
 23 
 24     #
 25 
 26     fnt = ImageFont.truetype('couri.ttf', 26)
 27 
 28     icntnr = av.open(args.inputpath)
 29     ocntnr = av.open(args.inputpath + ".out.mp4", "w")
 30 
 31     ivstrm = next(s for s in icntnr.streams if s.type == b'video')
 32     iastrm = next(s for s in icntnr.streams if s.type == b'audio')
 33     ostrms = {
 34         "audio": ocntnr.add_stream(codec_name=iastrm.codec.name, rate=iastrm.rate),
 35         "video": ocntnr.add_stream(codec_name=ivstrm.codec.name, rate=ivstrm.rate),
 36         }
 37     ostrms["video"].width = ivstrm.width
 38     ostrms["video"].height = ivstrm.height
 39 
 40     if args.start_sec:
 41         seek_pts_v = int(args.start_sec / float(ivstrm.time_base) + ivstrm.start_time)
 42         seek_pts_a = int(args.start_sec / float(iastrm.time_base) + iastrm.start_time)
 43         iastrm.seek(seek_pts_a)
 44         ivstrm.seek(seek_pts_v)
 45 
 46     def _get_pred(img):
 47         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 48         sock.connect((HOST, PORT))
 49         out = BytesIO()
 50         img.save(out, "PNG")
 51         imgbytes = out.getvalue()
 52         sock.sendall(struct.pack('!I', len(imgbytes)))
 53         sock.sendall(imgbytes)
 54         txt = sock.makefile('r').readlines()[0]
 55         sock.close()
 56         return json.loads(txt)
 57 
 58     count = 0
 59     for packet in icntnr.demux():
 60         for ifr in packet.decode():
 61             try:
 62                 r = q.get(block=False, timeout=1/500.)
 63                 if r:
 64                     ocntnr.close()
 65                     return
 66             except Empty as e:
 67                 pass
 68 
 69             typ = packet.stream.type
 70             ifr.pts = None
 71             if typ == 'video':
 72                 img = ifr.to_image()
 73                 tmpdctx = ImageDraw.Draw(img)
 74                 # --------------------------------------------
 75                 if count % args.step == 0:
 76                     txts = _get_pred(img)
 77                     logging.debug(txts)
 78                 # --------------------------------------------
 79                 x = 30
 80                 for txt in txts:
 81                     txtsz = tmpdctx.multiline_textsize(txt, fnt)
 82                     osd = Image.new("RGB", (txtsz[0] + 10, txtsz[1] + 10), "white")
 83                     dctx = ImageDraw.Draw(osd)
 84                     dctx.multiline_text((5, 5), txt, font=fnt, fill="black")
 85                     del dctx
 86                     img.paste(
 87                         osd,
 88                         box=(x, 30, osd.size[0] + x, osd.size[1] + 30),
 89                         mask=Image.new("L", osd.size, 192))
 90                     x += txtsz[0] + 50
 91                 del tmpdctx
 92 
 93                 ofr = av.VideoFrame.from_image(img)
 94                 for p in ostrms[typ].encode(ofr):
 95                     ocntnr.mux(p)
 96                 if count % args.step == 0:
 97                     logging.info("count={}".format(count))
 98                 count += 1
 99             else:
100                 for p in ostrms[typ].encode(ifr):
101                     ocntnr.mux(p)
102             if args.count and args.count <= count:
103                 ocntnr.close()
104                 return
105                 
106     ocntnr.close()
107 
108 
109 if __name__ == '__main__':
110     parser = argparse.ArgumentParser()
111     parser.add_argument("inputpath")
112     parser.add_argument("--count", type=int, default=0)
113     parser.add_argument("--step", type=int, default=30)
114     parser.add_argument("--start_sec", type=int, default=0)
115     args = parser.parse_args()
116 
117     q = Queue()
118     p = Process(target=_run, args=(args, q,))
119     p.start()
120     p.join()

こちらは Ctrl-C で止めてもちゃんと動画がちゃんとするように、とか、シークとか、「何フレームおきに予測させるか」などのことでゴチャついてるが、要するにこういったことをしないと、「べらぼーな時間がかかる」がために、気軽な試行錯誤が出来なくなってしまう。なお、本日時点での PyAV のマスターブランチを使うとこのプログラムは動作しない。PyAV の(生半可な)紹介にはそこらへんの事情は書いておいたんで、同じことをしてみたい人は注意。

てわけで、一つだけ試みてみた:

ワタシの「素朴な疑問」についてはあんまし良くわかんないんだよねぇ。「cockroach」な部分なんかはアスペクト比維持の方が若干いいのかな、なんてのもあるけれど、全体では別にどっちでもよさげな感じもする。要はどんな入力を使って学習させているか、てことだと思うんだけれど、「アスペクト比にあまり依存しないように」なんてことしてるのかしら?

やってみて思ったんだけど、こうやって予測と入力を同時に見れるようにしとくと、こういう検証しやすいよね、てことね。例えば前処理の有用性とかを知りたい、なんて目的にも一発だろう。

あとチョロっと上で書いたけど、「学習(トレーニング)を動画で」も、ひょっとしたら悪いアプローチではないかもしれなくて、これは「字幕」(subtitile)に正解を書き込んでおくことで、なんかイケそうな気がする。一応 PyAV、subtitle を読むことは出来る(書きは出来ないけど)。無論その subtitile 編集は間違いなく手間だけれど、そもそも「手間なく学習させる術」なんてのはないんだから、その中ではかなり有望そうだなぁ、と、感覚的には思う。

それとさ、これを発展させてヒートマップなんかを作れたら面白いんだけどなぁ、と思っている。けど初学者にはこれは当然まだツラい。