[Numba] Accelerating Computation

1. What Is Numba 🐍:

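Numba "just-in-time" compiles NumPy-based Python code into machine code, so it can run at close to native machine-code speed. The official documentation describes it as follows: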

Numba is a just-in-time compiler for Python that works best on code that uses NumPy arrays and functions, and loops. The most common way to use Numba is through its collection of decorators that can be applied to your functions to instruct Numba to compile them. When a call is made to a Numba-decorated function it is compiled to machine code “just-in-time” for execution and all or part of your code can subsequently run at native machine code speed!

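  • Pros:
    • Makes full use of CPU resources; GPU computing is also available through CUDA programming
    • Greatly speeds up large-matrix computation: the larger the matrix, the bigger the gain, which can reach one to two orders of magnitude
    • Loops can be executed in parallel
  • Cons:
    • Sorting is slightly slower than in NumPy
    • Little improvement for small-scale computations
    • Strict limits on package dependencies and data types (see the caveats section)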

2. Quick Start 👋:

Official getting-started tutorial:

https://numba.readthedocs.io/en/stable/user/5minguide.html

Installation (either command works):


pip install numba

conda install numba

A simple example:

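import numpy as np
from numba import njit

# NumPy implementation
compare_1w = np.random.rand(10000, 512)  # vectors to compare, shape (10000, 512)
base_10w = np.random.rand(512, 10000)    # base vectors, shape (512, 10000)
distance = np.dot(compare_1w, base_10w)

# Numba implementation: the same dot product inside an njit-compiled function
# (np.dot inside njit requires SciPy to be installed)
@njit
def cal_dot(a, b):
    return np.dot(a, b)

distance = cal_dot(compare_1w, base_10w)

# Result reported by the author: roughly a 30x speedup
# numpy cost: 0:00:45.465711
# numba cost: 0:00:01.553018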

3. Caveats: Common Errors and Fixes 🔧

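  • With njit(), a plain Python dict or an undecorated Python function cannot be passed in as an argument
    • Fix: use Numba only to produce the raw data, then post-process that result outside the decorated function, where dict and function arguments work as usual. For dictionaries that must live inside compiled code, Numba provides numba.typed.Dict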
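  • With njit(), the author hit errors when using continue, break, try, and except
    • Fix: restructure the logic to use only for, if, and else. Numba's documentation lists break and continue as supported and try/except as partially supported (a bare except, or except Exception), so whether this rewrite is needed may depend on the Numba version in use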
  • With njit(), a list whose elements have different types
    • Symptom: declaring a variable such as a = [(),[],[]]*10 fails with:
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
 >>> setitem(list(Tuple())<iv=None>, int64, Tuple(int64, array(int64, 1d, A), list(float32)<iv=None>))
    • Fix: restructure the data into an array, or split it across several arrays. The arrays must be allocated with a known size up front
  • With njit(), calling str.format()
    • Symptom:
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'format' of type Literal[str]([*] Calculate img_type:{} similarity, batch {}/10, progress: {}/{})
    • Fix: drop format() and build the string by concatenation instead
  • With njit(parallel=True), calling list.append()
    • Symptom: append is not thread-safe
File "/******/python3.9/site-packages/numba/cpython/listobj.py", line 1129, in list_to_list
    assert fromty.dtype == toty.dtype
AssertionError
    • Fix: remove append() from the decorated function; allocate the list or np.ndarray at its final size beforehand and write to it by index
  • With njit(parallel=True), an outdated TBB library
    • Symptom:
/*******/python3.9/site-packages/numba/np/ufunc/parallel.py:365: NumbaWarning: The TBB threading layer requires TBB version 2019.5 or later i.e., TBB_INTERFACE_VERSION >= 11005. Found TBB_INTERFACE_VERSION = 9107.

The TBB threading layer is disabled. warnings.warn(problem)
    • Fix: upgrade TBB with one of the commands below (see https://github.com/numba/numba/issues/6350)
conda install tbb

pip install --upgrade tbb

4. Appendix: Speed Test Code ⏱️

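import numpy as np
from numba import jit, njit
from datetime import datetime

@jit(nopython=True)
def cal(a, b):
    # Matrix product followed by a global max (np.dot inside njit requires SciPy)
    c = np.dot(b, a)
    d = np.max(c)
    return d  # return the result so the computation is actually used

@njit(parallel=True)
def gen(n):
    # Random test data: a has shape (512, 10**n), b has shape (100, 512)
    a = np.random.rand(512, 10**n)
    b = np.random.rand(100, 512)
    return a, b

def numba_cal(a, b):
    start = datetime.now()
    cal(a, b)
    end = datetime.now()
    return end - start

def np_cal(a, b):
    start = datetime.now()
    np.max(np.dot(b, a))  # same work as cal(), for a like-for-like comparison
    end = datetime.now()
    return end - start

# Warm-up call so JIT compilation time is not counted in the first measurement
cal(*gen(0))

# Note: at i = 7, a alone holds 512 * 10**7 float64 values (about 41 GB);
# lower the range if memory is limited
for i in range(8):
    a, b = gen(i)
    delta = numba_cal(a, b)
    print("scale {}: numba cost: {}".format(10**i, delta))
    delta = np_cal(a, b)
    print("scale {}: np cost: {}".format(10**i, delta))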

Original: https://blog.csdn.net/paperplaneY/article/details/119582920
Author: paperplaneY
Title: [Numba] Accelerating Computation
