The data-generation logic: from the single-sample function
def generate_one_training_data(key, m=100, P=1):
    # Sample one input function u ~ GP(0, RBF) on a fine grid of N points
    N = 512
    gp_params = (1.0, length_scale)
    jitter = 1e-10
    X = np.linspace(0, 1, N)[:, None]
    K = RBF(X, X, gp_params)
    L = np.linalg.cholesky(K + jitter * np.eye(N))
    gp_sample = np.dot(L, random.normal(key, (N,)))
    # u(t) is the linear interpolation of the GP sample
    u_fn = lambda x, t: np.interp(t, X.flatten(), gp_sample)

    # Evaluate u at the m sensor locations
    x = np.linspace(0, 1, m)
    u = vmap(u_fn, in_axes=(None, 0))(0.0, x)

    # Operator data: P random query locations y and the corresponding s(y),
    # obtained by integrating ds/dt = u(t) with s(0) = 0
    y_train = random.uniform(key, (P,)).sort()
    s_train = odeint(u_fn, 0.0, np.hstack((0.0, y_train)))[1:]
    u_train = np.tile(u, (P, 1))

    # Physics (residual) data: all m sensor locations, with s_r = u so that
    # the residual loss can enforce ds/dy = u(y)
    u_r_train = np.tile(u, (m, 1))
    y_r_train = x
    s_r_train = u

    return u_train, y_train, s_train, u_r_train, y_r_train, s_r_train
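Note that RBF and length_scale are defined elsewhere in the notebook. For reference, a compatible squared-exponential kernel could look like the sketch below; the imports, the kernel form, and the length_scale value are assumptions about that surrounding code:

import jax.numpy as np
from jax import random, vmap, jit, config
from jax.experimental.ode import odeint

length_scale = 0.2   # assumed value; the notebook sets its own

def RBF(x1, x2, params):
    # Squared-exponential kernel: k(x1, x2) = output_scale * exp(-||x1 - x2||^2 / (2 * lengthscale^2))
    output_scale, lengthscale = params
    diffs = np.expand_dims(x1 / lengthscale, 1) - np.expand_dims(x2 / lengthscale, 0)
    r2 = np.sum(diffs**2, axis=2)
    return output_scale * np.exp(-0.5 * r2)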
to the batched version
def generate_training_data(key, N, m, P):
    config.update("jax_enable_x64", True)
    keys = random.split(key, N)   # one independent key per sample
    gen_fn = jit(lambda key: generate_one_training_data(key, m, P))
    u_train, y_train, s_train, u_r_train, y_r_train, s_r_train = vmap(gen_fn)(keys)

    u_train = np.float32(u_train.reshape(N * P, -1))
    y_train = np.float32(y_train.reshape(N * P, -1))
    s_train = np.float32(s_train.reshape(N * P, -1))

    u_r_train = np.float32(u_r_train.reshape(N * m, -1))
    y_r_train = np.float32(y_r_train.reshape(N * m, -1))
    s_r_train = np.float32(s_r_train.reshape(N * m, -1))

    config.update("jax_enable_x64", False)
    return u_train, y_train, s_train, u_r_train, y_r_train, s_r_train
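With these helpers in scope, a quick shape check on a deliberately tiny N shows what the reshapes produce:

u_tr, y_tr, s_tr, u_r, y_r, s_r = generate_training_data(random.PRNGKey(0), N=4, m=100, P=1)
print(u_tr.shape, y_tr.shape, s_tr.shape)   # (4, 100) (4, 1) (4, 1):        N*P operator points
print(u_r.shape, y_r.shape, s_r.shape)      # (400, 100) (400, 1) (400, 1):  N*m residual points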
From this we can see that each of the N samples is generated with its own key, so every sample gets a different y_train (y_train is re-drawn at random for each sample).
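The per-sample keys come from random.split, and different keys give different draws, for example:

keys = random.split(random.PRNGKey(0), 3)
print(vmap(lambda k: random.uniform(k, (1,)))(keys))   # three distinct values

The training set itself is generated with: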
N_train = 10000
m = 100
P_train = 1
key_train = random.PRNGKey(0)
u_train, y_train, s_train, u_r_train, y_r_train, s_r_train = generate_training_data(key_train, N_train, m, P_train)
Finally, the key here is chosen to be different from the one used for the held-out data, so that the validation set and the training set contain no identical samples.
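For example, a held-out set could be drawn with a different seed; the size and the seed below are placeholders for illustration:

N_test = 1000                       # placeholder size
key_test = random.PRNGKey(12345)    # any seed other than PRNGKey(0)
u_test, y_test, s_test, u_r_test, y_r_test, s_r_test = generate_training_data(key_test, N_test, m, P_train)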
After all the data are generated, they are packed into two instances of the same dataset class:
batch_size = 10000
operator_dataset = DataGenerator(u_train, y_train, s_train, batch_size=batch_size)
physics_dataset = DataGenerator(u_r_train, y_r_train, s_r_train, batch_size=batch_size)
Passing batch_size=batch_size as an explicit keyword argument is the correct form here.
This batch_size determines how many samples each training iteration consumes.
The DataGenerator class:
class DataGenerator(data.Dataset):
    def __init__(self, u, y, s,
                 batch_size=64, rng_key=random.PRNGKey(1234)):
        'Initialization'
        self.u = u
        self.y = y
        self.s = s
        self.N = u.shape[0]
        self.batch_size = batch_size
        self.key = rng_key

    def __getitem__(self, index):
        'Generate one batch of data'
        self.key, subkey = random.split(self.key)
        inputs, outputs = self.__data_generation(subkey)
        return inputs, outputs

    @partial(jit, static_argnums=(0,))
    def __data_generation(self, key):
        'Generates data containing batch_size samples'
        idx = random.choice(key, self.N, (self.batch_size,), replace=False)
        s = self.s[idx, :]
        y = self.y[idx, :]
        u = self.u[idx, :]
        inputs = (u, y)
        outputs = s
        return inputs, outputs
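A minimal usage sketch: iterating (or indexing) the dataset yields ((u, y), s) batches whose leading dimension is batch_size:

operator_iter = iter(operator_dataset)
(u_b, y_b), s_b = next(operator_iter)
print(u_b.shape, y_b.shape, s_b.shape)   # (batch_size, m) (batch_size, 1) (batch_size, 1)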
A rather odd fact is that the index argument of __getitem__ is never used: whatever the index, a batch is sampled at random from the whole dataset, so every iteration draws a fresh random batch. Now let's look at the training code:
class PI_DeepONet:
    def __init__(self, branch_layers, trunk_layers):
        # Branch net encodes the sampled input function u; trunk net encodes the query location y
        self.branch_init, self.branch_apply = MLP(branch_layers, activation=np.tanh)
        self.trunk_init, self.trunk_apply = MLP(trunk_layers, activation=np.tanh)
        branch_params = self.branch_init(rng_key=random.PRNGKey(1234))
        trunk_params = self.trunk_init(rng_key=random.PRNGKey(4321))
        params = (branch_params, trunk_params)

        self.opt_init, \
        self.opt_update, \
        self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3,
                                                                       decay_steps=1000,
                                                                       decay_rate=0.9))
        self.opt_state = self.opt_init(params)

        self.itercount = itertools.count()
        self.loss_log = []
        self.loss_operator_log = []
        self.loss_physics_log = []

    def operator_net(self, params, u, y):
        branch_params, trunk_params = params
        B = self.branch_apply(branch_params, u)
        T = self.trunk_apply(trunk_params, y)
        # DeepONet output: dot product of branch and trunk features
        outputs = np.sum(B * T)
        return outputs

    def residual_net(self, params, u, y):
        # ds/dy via automatic differentiation of the operator output w.r.t. y
        s_y = grad(self.operator_net, argnums=2)(params, u, y)
        return s_y
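    # Note on the physics term: the data generation above integrates ds/dt = u(t)
    # with s(0) = 0 and stores s_r_train = u, so loss_physics below compares the
    # auto-diff derivative s_y against u(y), i.e. it enforces the ODE residual
    # ds/dy - u(y) = 0 at the collocation points.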
    def loss_operator(self, params, batch):
        inputs, outputs = batch
        u, y = inputs
        pred = vmap(self.operator_net, (None, 0, 0))(params, u, y)
        loss = np.mean((outputs.flatten() - pred.flatten())**2)
        return loss

    def loss_physics(self, params, batch):
        inputs, outputs = batch
        u, y = inputs
        pred = vmap(self.residual_net, (None, 0, 0))(params, u, y)
        loss = np.mean((outputs.flatten() - pred.flatten())**2)
        return loss

    def loss(self, params, operator_batch, physics_batch):
        loss_operator = self.loss_operator(params, operator_batch)
        loss_physics = self.loss_physics(params, physics_batch)
        loss = loss_operator + loss_physics
        return loss

    @partial(jit, static_argnums=(0,))
    def step(self, i, opt_state, operator_batch, physics_batch):
        params = self.get_params(opt_state)
        g = grad(self.loss)(params, operator_batch, physics_batch)
        return self.opt_update(i, g, opt_state)

    def train(self, operator_dataset, physics_dataset, nIter=10000):
        operator_data = iter(operator_dataset)
        physics_data = iter(physics_dataset)
        pbar = trange(nIter)
        for it in pbar:
            operator_batch = next(operator_data)
            physics_batch = next(physics_data)
            self.opt_state = self.step(next(self.itercount), self.opt_state, operator_batch, physics_batch)
            if it % 100 == 0:
                params = self.get_params(self.opt_state)
                loss_value = self.loss(params, operator_batch, physics_batch)
                loss_operator_value = self.loss_operator(params, operator_batch)
                loss_physics_value = self.loss_physics(params, physics_batch)
                self.loss_log.append(loss_value)
                self.loss_operator_log.append(loss_operator_value)
                self.loss_physics_log.append(loss_physics_value)
                pbar.set_postfix({'Loss': loss_value,
                                  'loss_operator': loss_operator_value,
                                  'loss_physics': loss_physics_value})

    @partial(jit, static_argnums=(0,))
    def predict_s(self, params, U_star, Y_star):
        s_pred = vmap(self.operator_net, (None, 0, 0))(params, U_star, Y_star)
        return s_pred

    @partial(jit, static_argnums=(0,))
    def predict_s_y(self, params, U_star, Y_star):
        s_y_pred = vmap(self.residual_net, (None, 0, 0))(params, U_star, Y_star)
        return s_y_pred
As noted earlier, the two lines below effectively create infinite iterators; in the train method, every iteration step samples one physics batch and one bc (operator) batch:
operator_data = iter(operator_dataset)
physics_data = iter(physics_dataset)
During training, each iteration step calls next() on them to take one batch from the physics dataset and one from the bc (operator) dataset; since __getitem__ always returns a freshly sampled batch and never raises IndexError, these iterators never run out.
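Putting it all together, a full run could look like the sketch below. The layer widths, iteration count, and the prediction call are illustrative assumptions, not the notebook's exact choices (MLP is the fully connected network helper used in __init__):

branch_layers = [m, 50, 50, 50, 50, 50]   # branch input: the m sensor values of u
trunk_layers  = [1, 50, 50, 50, 50, 50]   # trunk input: the scalar query location y
model = PI_DeepONet(branch_layers, trunk_layers)
model.train(operator_dataset, physics_dataset, nIter=10000)

params = model.get_params(model.opt_state)
s_pred = model.predict_s(params, u_test, y_test)   # assumes the held-out set sketched earlier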