jax方法(.at .set .add .multiply)
-
位置索引.at 设置新值.set:
import jax.numpy as jnp # 创建一个示例数组 a = jnp.array([1, 2, 3, 4, 5]) # 使用 .at 和 .set 方法将索引为 2 的元素设置为 10 b = a.at[2].set(10) print("Original array:", a) #Original array: [1 2 3 4 5] print("Updated array:", b) #Updated array: [ 1 2 10 4 5]
-
累加值.add:
# 使用 .at 和 .add 方法将索引为 2 的元素加上 10 c = a.at[2].add(10) print("Original array:", a) #Original array: [1 2 3 4 5] print("Updated array:", c) #Updated array: [ 1 2 13 4 5]
-
乘以值 .multiply:
# 使用 .at 和 .multiply 方法将索引为 2 的元素乘以 10 d = a.at[2].multiply(10) print("Original array:", a) #Original array: [1 2 3 4 5] print("Updated array:", d) #Updated array: [ 1 2 30 4 5]
-
多维数组的操作:
# 创建一个多维数组 a = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # 使用 .at 和 .set 方法将 (0, 1) 和 (2, 2) 元素设置为 10 b = a.at[(0, 1)].set(10) c = b.at[(2, 2)].set(10) print("Original array:\n", a) #Original array: [[1 2 3] [4 5 6] [7 8 9]] print("Updated array:\n", c) #Updated array: [[ 1 10 3] [ 4 5 6] [ 7 8 10]]