目录

  1. 代码
  2. Reference
  3. END

很常见的一个题目:给定升序排序数组 a,查找 a 中大于等于 x 的第一个数的索引。

有多种方法来完成这个题目,大部分人第一反应都是二分查找。我喜欢从实践出发来想问题。我有几个问题:

  1. 二分查找真的很快吗?如果快,快多少?
  2. 实际中我们应该怎么写才能更好地完成这个题目?

本文就是来回答这两个问题的。我不会在本文中赘述二分查找的思想原理,不了解的同学可以先出门搜索了解以下。

来点硬的,直接来跑一下结果。你可以在 colab 中自己试验一下。

代码

我们这里测试了 4 种方法,如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def recursive(a, low, high, x):
'''递归写法'''
middle = int((low + high) / 2)
if high >= low:
if a[middle] == x:
return middle
elif a[middle] > x:
return recursive(a, low, middle - 1, x)
else:
return recursive(a, middle + 1, high, x)
if x < a[middle]:
return middle
if x > a[middle]:
return middle + 1


def whileloop(a, x):
'''while 循环写法'''
low = 0
high = len(a) - 1
while low <= high:
middle = int((low + high) / 2)
if a[middle] < x:
low = middle + 1
elif a[middle] > x:
high = middle - 1
else:
return middle
if x < a[middle]:
return middle
if x > a[middle]:
return middle + 1


def forloop(a, x):
'''for 循环写法'''
for idx, item in enumerate(a):
if item >= x:
return idx


def builtin_bisect(a, x):
'''使用内置的 bisect'''
return bisect.bisect_left(a, x)

共进行 3 次测试,各次 a 的长度为 $10^3$、$10^6$ 和 $10^9$,x 为长度的 0.67,测试 CPU 环境为 Intel(R) Xeon(R) CPU @ 2.00GHz,Python 版本为 Python 3.6.8,就是 colab 的环境:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
recursive_times = []
whileloop_times = []
forloop_times = []
builtin_times = []
lengths = [int(1e3), int(1e6), int(1e9)]
for i in lengths:
a = range(i)
x = 0.67 * i
print(i)
start = time.time()
_ = recursive(a, 0, i - 1, x)
cost = time.time() - start
recursive_times.append(cost)

start = time.time()
_ = whileloop(a, x)
cost = time.time() - start
whileloop_times.append(cost)

start = time.time()
_ = forloop(a, x)
cost = time.time() - start
forloop_times.append(cost)

start = time.time()
_ = builtin_bisect(a, x)
cost = time.time() - start
builtin_times.append(cost)

fig = plt.figure(figsize=(15, 8))
plt.plot(range(3), [i * 1000 for i in recursive_times], '-o', label="recursive")
plt.plot(range(3), [i * 1000 for i in whileloop_times], '-o', label="whileloop")
plt.plot(range(3), [i * 1000 for i in forloop_times], '-o', label="forloop")
plt.plot(range(3), [i * 1000 for i in builtin_times], '-o', label="builtin_bisect")
plt.xticks([0, 1, 2], [3, 6, 9])
plt.xlabel("length/10^")
plt.ylabel("COST/ms")
plt.legend()
plt.grid(True)

测试结果:

测试结果测试结果

可以看到随着数组长度的扩大,for 循环方法消耗的时间也非常非常快的增长,而其他方法波动很小。具体而言,数组长度变为原来的 1000 倍,for 循环方法耗时也变为原来的 1000 倍左右,而其他方法的耗时变为原来的 1-3 倍左右。

暂且除去 for 循环,如果我们细看在 $10^3$、$10^6$ 和 $10^9$ 处的图,可以看到剩余三个方法之间的差距,递归耗时均最高:

实际上 bisect 内部实现用的就是 while 循环的方法,代码很短,我直接贴过来(吐槽下官方代码竟然没有很好的格式化):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def bisect_left(a, x, lo=0, hi=None):
"""Return the index where to insert item x in list a, assuming a is sorted.
The return value i is such that all e in a[:i] have e < x, and all e in
a[i:] have e >= x. So if x already appears in the list, a.insert(x) will
insert just before the leftmost x already there.
Optional args lo (default 0) and hi (default len(a)) bound the
slice of a to be searched.
"""

if lo < 0:
raise ValueError('lo must be non-negative')
if hi is None:
hi = len(a)
while lo < hi:
mid = (lo+hi)//2
if a[mid] < x: lo = mid+1
else: hi = mid
return lo

详细的测试时间数据如下:

1
2
3
4
recursive_times=[1.5735626220703125e-05, 2.5272369384765625e-05, 4.553794860839844e-05]
whileloop_times=[6.67572021484375e-06, 1.1205673217773438e-05, 1.8358230590820312e-05]
forloop_times=[4.982948303222656e-05, 0.058376312255859375, 58.462812423706055]
builtin_times=[6.4373016357421875e-06, 1.9311904907226562e-05, 2.4557113647460938e-05]

到此时我们回答了第一个问题,二分查找真的很快! 在数组长度较小时,差距还不是那么的明显。但是随着数组长度扩大,差距简直指数级扩大,差距甚至几百万倍。

至于第二个问题,我认为的答案是尽可能使用内置库,这些都是经过了很多优化,速度和稳定性都有保证,写起来还简单粗暴,何乐而不为呢?

写完感觉一篇废话 😂

Reference

END