正弦函数预测
设计 LSTM 网络，利用已知的样本数据对正弦函数进行预测，记录预测精度(以均方误差 MSE 衡量)，并绘制出图形。
1.准备数据
import torch
import torchvision.datasets as datasets
import torch.nn as nn
import torch. nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import matplotlib.pyplot as plt
import numpy as np
import time
def getSeq(start, n):
    """Return ``n`` consecutive samples of ``sin(t / 10)`` for ``t = start .. start + n - 1``.

    The samples are returned as a plain Python list of NumPy floats.
    """
    return [np.sin(t / 10) for t in range(start, start + n)]


def gen_data(n_samples=200, window=10, n_train=150, batch_size=5):
    """Build train/test loaders for one-step-ahead sine prediction.

    Each sample is a sliding window over ``sin(t / 10)`` starting at
    ``t = i`` for ``i = 0 .. n_samples - 1``. The first ``window`` values are
    the input, and the value immediately after them is the target. The first
    ``n_train`` windows form the training set and the rest form the test set.

    Args:
        n_samples: number of windows to generate (default 200).
        window: input sequence length (default 10).
        n_train: number of leading windows used for training (default 150).
        batch_size: mini-batch size for both loaders (default 5).

    Returns:
        ``(train_y, train_loader, test_loader)``:

        - ``train_y`` is a float32 tensor of shape ``(n_train, 1)`` holding
          the training targets.
        - ``train_loader`` shuffles its samples.
        - ``test_loader`` keeps the original order.
    """
    seqs = np.array([getSeq(i, window + 1) for i in range(n_samples)])
    seqs = torch.from_numpy(seqs).float()
    # Take the target from the untrimmed sequences. Slicing the inputs first
    # and then taking the target from them would make the target equal the
    # last input value, turning the task into an identity copy.
    inputs = seqs[:, :-1]
    target = seqs[:, -1:]

    train_x, train_y = inputs[:n_train], target[:n_train]
    test_x, test_y = inputs[n_train:], target[n_train:]

    train_loader = DataLoader(
        dataset=TensorDataset(train_x, train_y), batch_size=batch_size, shuffle=True
    )
    test_loader = DataLoader(
        dataset=TensorDataset(test_x, test_y), batch_size=batch_size, shuffle=False
    )
    return train_y, train_loader, test_loader
2.构建模型
class LSTM(nn.Module):
    """Single-layer LSTM regressor that maps each input sequence to one output vector.

    Inputs are batch-first, i.e. ``(batch, seq_len, input_size)``. Only the
    LSTM output at the last time step is passed to the linear head.
    """

    def __init__(self, input_size=1, hidden_size=10, output_size=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        """Return predictions of shape ``(batch, output_size)``.

        Args:
            x: tensor of shape ``(batch, seq_len, input_size)``.
            hidden: optional ``(h0, c0)`` pair, each of shape
                ``(1, batch, hidden_size)``. With ``None`` (the default),
                ``nn.LSTM`` starts from zeros sized to the actual batch, so a
                short final batch needs no special handling.
        """
        output, _ = self.lstm(x, hidden)
        # Keep only the last time step's output for the regression head.
        return self.fc(output[:, -1, :])
3.训练模型
def train(train_loader, lstm, loss_func, opt, epochs, plot=True):
    """Fit ``lstm`` on ``train_loader`` and optionally plot the loss curve.

    Args:
        train_loader: iterable of ``(x, y)`` batches. ``x`` is reshaped to
            ``(batch, seq_len, 1)`` before being fed to the model.
        lstm: model, called as ``lstm(x, hidden)``.
        loss_func: loss called as ``loss_func(pred, y)``, e.g. ``nn.MSELoss()``.
        opt: optimizer over ``lstm``'s parameters.
        epochs: number of passes over ``train_loader``.
        plot: if True (the default), show the per-epoch loss curve with matplotlib.

    Returns:
        List with the mean training loss of each epoch, as Python floats
        (``nan`` for an epoch with no batches).
    """
    lstm.train()
    train_loss_list = []
    for epoch in range(epochs):
        batch_losses = []
        for x, y in train_loader:
            x = x.view(x.size(0), -1, 1)
            # hidden=None lets nn.LSTM start from zeros sized to this batch.
            # A hard-coded batch of 5 would crash on a shorter final batch.
            pred = lstm(x, None)
            loss = loss_func(pred, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            batch_losses.append(loss.item())
        # Record the epoch's mean loss, not just its last batch.
        epoch_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")
        train_loss_list.append(epoch_loss)
        if epoch % 100 == 0:
            print("LSTM_loss: ", epoch_loss)
    if plot:
        plt.plot(range(epochs), train_loss_list, 'r', label='train')
        plt.xlabel('epochs')
        plt.ylabel('loss')
        plt.legend()
        plt.show()
    return train_loss_list
# Model, objective and optimizer.
lstm = LSTM()
loss_func = nn.MSELoss()
opt = torch.optim.Adam(params=lstm.parameters(), lr=1e-3)
epochs = 500

# Build the datasets, then fit the model.
train_y, train_loader, test_loader = gen_data()
train(train_loader=train_loader, lstm=lstm, loss_func=loss_func, opt=opt, epochs=epochs)
4.预测结果
def _collect_predictions(test_loader, lstm):
    """Run ``lstm`` over ``test_loader`` without gradients; return flat ``(preds, targets)`` arrays."""
    was_training = lstm.training
    lstm.eval()
    preds, targets = [], []
    try:
        with torch.no_grad():
            for x, y in test_loader:
                # (batch, seq_len) -> (batch, seq_len, 1). hidden=None starts
                # from zeros sized to the actual batch.
                pred = lstm(x.view(x.size(0), -1, 1), None)
                preds.append(pred.reshape(-1))
                targets.append(y.reshape(-1))
    finally:
        # Restore whatever mode the caller had the model in.
        lstm.train(was_training)
    if not preds:
        empty = np.empty(0, dtype=np.float32)
        return empty, empty.copy()
    return torch.cat(preds).numpy(), torch.cat(targets).numpy()


def prediction(test_loader, lstm, train_y=None, start=150, plot=True):
    """Predict every test sample, report the MSE and plot predictions vs. ground truth.

    Args:
        test_loader: iterable of ``(x, y)`` batches, visited in plotting order.
        lstm: model, called as ``lstm(x, hidden)``.
        train_y: optional training targets, drawn as history at x = 0 .. len - 1.
        start: x-coordinate of the first test sample. The default of 150 matches
            the 150-sample training split used in this file.
        plot: if True (the default), show the figure with matplotlib.

    Returns:
        1-D NumPy array with one prediction per test sample, in loader order.
    """
    preds, targets = _collect_predictions(test_loader, lstm)
    if len(preds):
        # Regression quality is measured as the mean squared error.
        print("test MSE: ", float(np.mean((preds - targets) ** 2)))
    if plot:
        if train_y is not None:
            history = torch.as_tensor(train_y).detach().cpu().reshape(-1).numpy()
            plt.scatter(np.arange(len(history)), history, marker='o', label='train_true')
        xs = np.arange(start, start + len(preds))
        plt.scatter(xs, targets, marker='o', label='test_true')
        plt.scatter(xs, preds, marker='s', label='lstm_pred')
        plt.legend(loc='best')
        plt.xlabel('x')
        plt.ylabel('y')
        plt.show()
    return preds
# Evaluate the trained model on the held-out windows and plot the result.
prediction(test_loader=test_loader, lstm=lstm)