我需要求解一些常微分方程 $\frac{d y}{d x} = f(x) = x^2 ln(x)$ 并继续在限制 0 之间创建数组 xpt。 <= xpt <= 2因为我必须小心 xpt = 0,所以我将函数定义如下
def f(x):
if x <= 1.e-6:
return 0.
else:
return np.square(x)*np.log(x)
我的调用程序读取
Npt = 200
xpt = np.linspace( 0.,2.,Npt)
fpt = np.zeros(Npt)
但是,当我拨打电话时 fpt = f(xpt) 我收到错误
“ValueError:具有多个元素的数组的真值不明确。使用 a.any() 或 a.all()”
我可以通过引入 for 循环来解决此问题,写
for ip in range(Npt):
fpt[ip] = f(xpt[ip])
但这似乎是一种黑客行为并且不能令人满意。
我尝试查看使用 a,any() 的建议并将该函数重新定义为
def Newf(x):
if ((x <= 1.e-6).all()):
return 0.
else:
return np.square(x*np.log(x))
但这似乎将 f(0.) 作为 nan。
任何有关如何继续的帮助都深表感谢。
遇到的问题是,正试图在
if
语句中对整个 NumPy 数组执行比较。
x < 1.e-6
返回一个布尔数组,的
if
语句不知道如何处理它。
无需使用循环或
any()
/
all()
即可实现此目标。NumPy 提供了一种称为
向量化
的功能,允许对整个数组执行操作。
以下是修复后的代码:
import numpy as np
def f(x):
# 使用 np.where 创建一个掩码,以有条件地应用条件
return np.where(x <= 1.e-6, 0, np.square(x) * np.log(x))
Npt = 200
xpt = np.linspace(0., 2., Npt)
fpt = f(xpt) # 现在可以直接将整个数组传递给 f
print(fpt)
这个版本的代码做了什么:
-
np.where(条件,值 if True,值 if False)
:此函数允许我们在数组中应用基于条件的逻辑。 -
x <= 1.e-6
:这为数组x
中的每个元素创建一个布尔数组,其中True
表示值小于或等于1.e-6
。 -
0
:如果条件为True
(即x
中的值小于或等于1.e-6
),则返回此值。 -
np.square(x) * np.log(x)
:如果条件为False
,则返回此值。
通过使用
np.where
,避免了需要显式循环,并且代码将更有效地运行,因为 NumPy 的向量化操作针对速度进行了优化。