TensorFlow Lite + GPUImage 实现AI背景虚化(二)

最近项目中有研究背景虚化功能,需求是通过写一个GPUImage的滤镜,结合TensorFLow Lite来实现对图片中指定物体的背景虚化功能。这部分内容基本都是通过看官方文档和自己摸索学习,这里总结并整理一份笔记,内容主要包括Android接入TensorFlow Lite,通过运行AI模型来识别图片中物体,并对其做背景虚化。一共分为三部分,本文是第二部分,用GPUImage库来实现目标物体的背景虚化。

什么是GPUImage

GPUImage是一个用来给图片加滤镜的开源框架。不止图片,也可以动态的给照相机镜头加滤镜,使得拍出来的照片自带滤镜。所以如果有做相机类app、图片的修图功能等,可以使用这个库。这个库的地址在这里:

GPUImage

是的你没看错,他确实是一个iOS的库,当然Android也有一个由他魔改而来的 android-gpuimage 。不过毕竟是从iOS搬运过来的,这个Android库存在很多坑,而且他所支持的滤镜数量也不及原版。因此在使用过程中,如果有一些实在无法解决的问题,可以参考iOS的实现,自己改一改。

这个库的引入也很简单,这里就不细说了,具体直接参考README文档即可,接下里介绍下代码中具体的使用方法。

基本的滤镜使用

使用起来也很简单,首先在xml文件中这也写:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
android:layout_width="match_parent"
android:layout_height="match_parent">

<jp.co.cyberagent.android.gpuimage.GPUImageView
android:id="@+id/gpuImageView"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_centerInParent="true"
app:gpuimage_surface_type="texture_view" />

</RelativeLayout>

这个GPUImageView就是GPUIMage库自带的一个控件,可以当做ImageView来使用。要加载图片,也是跟ImageView控件一样:

1
gpuImageView.setImage(bitmap)

然后就是给图片加滤镜了。GPUImage自带有很多滤镜,这里取其中几个举个例子,比如 GPUImageBrightnessFilter滤镜,就是调节图片亮度的,他的调节范围是-1到1,默认是0:

GPUImageBrightnessFilter: Adjusts the brightness of the image

brightness: The adjusted brightness (-1.0 - 1.0, with 0.0 as the default)

添加方式:

1
2
3
4

val filter = GPUImageBrightnessFilter()
filter.setBrightness(0.5f)//亮度调节
gpuImageView.filter = GPUImageBrightnessFilter()

最终的效果就是这样的:

分别是原图、亮度强度为0.3、亮度强度-0.3的效果。

背景虚化的原理

实现目标背景虚化的原理很简单:找出图片中的目标物体,只对目标物体之外的背景部分做模糊处理,目标物体不变。所以大体上是有两套方案:

  1. 自己写一个模糊滤镜,在滤镜中自己判断是否需要模糊
  2. 使用自带的模糊滤镜对图片整体模糊,在显示的时候再去判断如果是目标物体就显示原图,非目标物体则显示模糊图

这两种方案,显然方案1会复杂一点,方案2实现起来相对简单,我目前使用的也是方案2。后面有空的话再去研究下方案1的实现。

方案二,比较笨的办法是写两个ImageView,上层是抠出来的原图,下层是模糊的图片,不过这个写法实在有点蠢。另外一个比较好的办法就是自己写一个滤镜,这个滤镜的输入是一张原图和一个可以确定目标物体轮廓的数据,在滤镜中做背景虚化并输出,这样一来这个滤镜也方便在其他地方使用。不过因为Opengl不能直接接受数据输入,所以需要把轮廓数据转换为一张Mask图片作为纹理传入。

这个方案具体如何做呢?

  1. 模糊效果是可以直接用GPUImage自带的模糊滤镜;
  2. 轮廓可以用之前用Tensorflow Lite生成的Mask图;
  3. 自定义滤镜A,写一个能够接受多个图片作为输入源的滤镜。GPUImage是基于OpenGL ES实现的,且已经封装好了很多OpenGL的方法,我们只需要写着色器即可。对OpenGL ES不太了解的话可以先大概学习一下基本概念:LearnOpenGL
  4. 在上一步的滤镜中,将原图和目标物体轮廓图作为输入源加入到滤镜中;
  5. GPUImage提供了一个滤镜叫GPUImageFilterGroup,这是一个滤镜组,可以接收多个滤镜;所以需要重写GPUImageFilterGroup,往里面依次加入模糊滤镜和我们前面写的滤镜A,在这个滤镜的片元着色器中具体去操作像素点;

具体实现

接下来就是具体的实现方式了。

自定义 GPUImageThreeInputFilter 滤镜

首先,要写一个可以接收多个图片作为输入源的滤镜。其实这个滤镜在iOS的库里面是有的,可以接收两个、三个乃至四个图片的输入,但是Android版本的库里面只有一个仅能接收两个图片输入的 GPUImageTwoInputFilter 滤镜。而我们所需要的滤镜要能接收三张图片:加了滤镜后的图片、原图、描出目标物体的Mask图,所以需要参考 GPUImageTwoInputFilter 自己写一个。具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132

open class GPUImageThreeInputFilter(vertexShader: String?, fragmentShader: String?) :
GPUImageFilter(vertexShader, fragmentShader) {
private var filterSecondTextureCoordinateAttribute = 0
private var filterInputTextureUniform2 = 0
private var filterSourceTexture2 = OpenGlUtils.NO_TEXTURE
private var texture2CoordinatesBuffer: ByteBuffer? = null
private var bitmap2: Bitmap? = null

private var filterThirdTextureCoordinateAttribute = 0
private var filterInputTextureUniform3 = 0
private var filterSourceTexture3 = OpenGlUtils.NO_TEXTURE
private var texture3CoordinatesBuffer: ByteBuffer? = null
private var bitmap3: Bitmap? = null

constructor(fragmentShader: String?) : this(VERTEX_SHADER, fragmentShader) {}

override fun onInit() {
super.onInit()
filterSecondTextureCoordinateAttribute = GLES20.glGetAttribLocation(program, "inputTextureCoordinate2")
filterInputTextureUniform2 = GLES20.glGetUniformLocation(program, "inputImageTexture2")
// This does assume a name of "inputImageTexture2" for second input texture in the fragment shader
GLES20.glEnableVertexAttribArray(filterSecondTextureCoordinateAttribute)

filterThirdTextureCoordinateAttribute = GLES20.glGetAttribLocation(program, "inputTextureCoordinate3")
filterInputTextureUniform3 = GLES20.glGetUniformLocation(program, "inputImageTexture3")
// This does assume a name of "inputImageTexture3" for second input texture in the fragment shader
GLES20.glEnableVertexAttribArray(filterThirdTextureCoordinateAttribute)
}

override fun onInitialized() {
super.onInitialized()
if (bitmap2 != null && !bitmap2!!.isRecycled && bitmap3 != null && !bitmap3!!.isRecycled) {
setBitmap(bitmap2, bitmap3)
}
}

fun setBitmap(bitmap2: Bitmap?, bitmap3: Bitmap?) {
if ((bitmap2 != null && bitmap2.isRecycled)
&& (bitmap3 != null && bitmap3.isRecycled)) {
return
}
this.bitmap2 = bitmap2
this.bitmap3 = bitmap3

if (this.bitmap2 == null || this.bitmap3 == null) {
return
}
runOnDraw(Runnable {
if (filterSourceTexture2 == OpenGlUtils.NO_TEXTURE
&& filterSourceTexture3 == OpenGlUtils.NO_TEXTURE) {
if ((bitmap2 == null || bitmap2.isRecycled)
|| (bitmap3 == null || bitmap3.isRecycled)) {
return@Runnable
}
GLES20.glActiveTexture(GLES20.GL_TEXTURE3)
filterSourceTexture2 = OpenGlUtils.loadTexture(bitmap2, OpenGlUtils.NO_TEXTURE, false)
GLES20.glActiveTexture(GLES20.GL_TEXTURE4)
filterSourceTexture3 = OpenGlUtils.loadTexture(bitmap3, OpenGlUtils.NO_TEXTURE, false)
}
})
}

override fun onDestroy() {
super.onDestroy()
GLES20.glDeleteTextures(1, intArrayOf(filterSourceTexture2), 0)
filterSourceTexture2 = OpenGlUtils.NO_TEXTURE

GLES20.glDeleteTextures(1, intArrayOf(filterSourceTexture3), 0)
filterSourceTexture3 = OpenGlUtils.NO_TEXTURE
}

override fun onDrawArraysPre() {

GLES20.glEnableVertexAttribArray(filterSecondTextureCoordinateAttribute)
GLES20.glActiveTexture(GLES20.GL_TEXTURE3)
GLES20.glBindTexture(GLES20.GL_TEXTURE_2D, filterSourceTexture2)
GLES20.glUniform1i(filterInputTextureUniform2, 3)
texture2CoordinatesBuffer!!.position(0)
GLES20.glVertexAttribPointer(filterSecondTextureCoordinateAttribute, 2, GLES20.GL_FLOAT, false, 0, texture2CoordinatesBuffer)

GLES20.glEnableVertexAttribArray(filterThirdTextureCoordinateAttribute)
GLES20.glActiveTexture(GLES20.GL_TEXTURE4)
GLES20.glBindTexture(GLES20.GL_TEXTURE_2D, filterSourceTexture3)
GLES20.glUniform1i(filterInputTextureUniform3, 4)
texture3CoordinatesBuffer!!.position(0)
GLES20.glVertexAttribPointer(filterThirdTextureCoordinateAttribute, 2, GLES20.GL_FLOAT, false, 0, texture3CoordinatesBuffer)
}

fun setRotation2(rotation: Rotation?, flipHorizontal: Boolean, flipVertical: Boolean) {
val buffer = TextureRotationUtil.getRotation(rotation, flipHorizontal, flipVertical)
val bBuffer = ByteBuffer.allocateDirect(32).order(ByteOrder.nativeOrder())
val fBuffer = bBuffer.asFloatBuffer()
fBuffer.put(buffer)
fBuffer.flip()
texture2CoordinatesBuffer = bBuffer
}

fun setRotation3(rotation: Rotation?, flipHorizontal: Boolean, flipVertical: Boolean) {
val buffer = TextureRotationUtil.getRotation(rotation, flipHorizontal, flipVertical)
val bBuffer = ByteBuffer.allocateDirect(32).order(ByteOrder.nativeOrder())
val fBuffer = bBuffer.asFloatBuffer()
fBuffer.put(buffer)
fBuffer.flip()
texture3CoordinatesBuffer = bBuffer
}

companion object {
private const val VERTEX_SHADER =
"attribute vec4 position;\n" +
"attribute vec4 inputTextureCoordinate;\n" +
"attribute vec4 inputTextureCoordinate2;\n" +
"attribute vec4 inputTextureCoordinate3;\n" +
" \n" +
"varying vec2 textureCoordinate;\n" +
"varying vec2 textureCoordinate2;\n" +
"varying vec2 textureCoordinate3;\n" +
" \n" +
"void main()\n" +
"{\n" +
" gl_Position = position;\n" +
" textureCoordinate = inputTextureCoordinate.xy;\n" +
" textureCoordinate2 = inputTextureCoordinate2.xy;\n" +
" textureCoordinate3 = inputTextureCoordinate3.xy;\n" +
"}"
}

init {
setRotation2(Rotation.NORMAL, false, false)
setRotation3(Rotation.NORMAL, false, false)
}
}

说是接受三个图片作为输入,但是为什么 setBitmap() 方法只接受两个参数呢?其实这个滤镜在通过· setFilter() 或其他方式加到图片上之后,就会默认把当前状态的图片作为一个输入源。在顶点着色器里面的 inputTextureCoordinate 就是代表了默认的输入源坐标。

自定义 BokehBluFilter 滤镜

在上篇文章中,我们已经获取到了一张由目标物体的轮廓绘制出的Mask图。以猫为例,假如我们绘制出的Mask图中,猫的部分红色,非猫部分黑色,即猫的部分色值 r=255,g=0,b=0 , 非猫的背景部分色值为 r=0,g=0,b=0,那么效果如下:

原图:

结果图:

然后自定义一个 BokehBlurFilter,继承自 GPUImageFilterGroup,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

class BokehBlurFilter(val original: Bitmap? = null, segMask: Bitmap? = null) : GPUImageFilterGroup() {

init {
if (original != null && segMask != null) {
val threeInputFilter = GPUImageThreeInputFilter(fragmentShader)

threenputFilter.setBitmap(original, segMask)
addFilter(GPUImageGaussianBlurFilter())
addFilter(threeInputFilter)
}
}

fun setBlurSize(intensity: Float) {
(filters[0] as GPUImageGaussianBlurFilter).setBlurSize(intensity * MAX_BLUR_SIZE)
updateMergedFilters()
}

companion object {

var fragmentShader =
"varying highp vec2 textureCoordinate;\n" +
"varying highp vec2 textureCoordinate2;\n" +
"varying highp vec2 textureCoordinate3;\n" +
"\n" +
" uniform sampler2D inputImageTexture;\n" +
" uniform sampler2D inputImageTexture2;\n" +
" uniform sampler2D inputImageTexture3;\n" +
" \n" +
" void main()\n" +
" {\n" +
" mediump vec4 textureBlur = texture2D(inputImageTexture, textureCoordinate);\n" +
" mediump vec4 textureOriginal = texture2D(inputImageTexture2, textureCoordinate2);\n" +
" mediump vec4 textureMask = texture2D(inputImageTexture3, textureCoordinate3);\n" +
" gl_FragColor = textureMask;\n" +
" }"
}

}

代码很简单,就是继承了 GPUImageFilterGroup,在初始化时向父类所持有的滤镜列表中添加两个滤镜,第一个是模糊滤镜,这里我用了自带的高斯模糊滤镜;第二个是前面写的 GPUImageThreeInputFilter 滤镜,并且在 GPUImageThreeInputFilter 滤镜中依次把原图和目标物体轮廓Mask图作为输入源。这里在创建 GPUImageThreeInputFilter 对象时传入了一个 fragmentShader 参数作为片元着色器。在片元着色器对每个像素点做具体操作。这里只是举个例子,显示了Mask图。了解了glsl语言就知道这里该怎么写。我们的背景虚化具体就是在这里做处理了。

片元着色器

在片元着色器里面这样写就可以了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

const val fragmentShader =
"varying highp vec2 textureCoordinate;\n" +
"varying highp vec2 textureCoordinate2;\n" +
"varying highp vec2 textureCoordinate3;\n" +
"\n" +
" uniform sampler2D inputImageTexture;\n" +
" uniform sampler2D inputImageTexture2;\n" +
" uniform sampler2D inputImageTexture3;\n" +
" \n" +
" void main()\n" +
" {\n" +
" mediump vec4 textureBlur = texture2D(inputImageTexture, textureCoordinate);\n" +
" mediump vec4 textureMask = texture2D(inputImageTexture2, textureCoordinate2);\n" +
" mediump vec4 textureOriginal = texture2D(inputImageTexture3, textureCoordinate3);\n" +
" if(textureMask.r == 0.0){\n" +
" gl_FragColor = textureOriginal;\n" +
" }else{\n" +
" gl_FragColor = textureBlur;\n" +
" }\n" +
" }"

很简单,textureBlur 是模糊后图片的纹理, textureMask 是目标物体轮廓Mask图的纹理, textureOriginal 是原图的纹理。以前面那个猫的图片为例,因为Mask图中,猫的部分是黑色,非猫部分是红色,也就是说猫的部分,像素点的r为0,非猫部分则为1,所以在着色器中根基r值是否为0来判断当前点显示的纹理是原图还是模糊图。

最后,将这个滤镜设置给原图,效果如下:

可以看到,模糊对猫的背景虚化是成功了,但是还存在一些问题,比如:猫的边缘比较生硬等,另外在实际使用过程中也还是有一些坑。下一篇文章中具体来总结一下遇到的问题和优化方案。