基于Pytorch,使用改良的Transformer模型应用于多维时间序列的分类任务上
对比模型选择 Fully Convolutional Networks (FCN) and Residual Net works (ResNet)
DataSet | MLP | FCN | ResNet | Encoder | MCNN | t-LeNet | MCDCNN | Time-CNN | TWIESN | Gated Transformer |
---|---|---|---|---|---|---|---|---|---|---|
ArabicDigits | 96.9 | 99.4 | 99.6 | 98.1 | 10.0 | 10.0 | 95.9 | 95.8 | 85.3 | 98.8 |
AUSLAN | 93.3 | 97.5 | 97.4 | 93.8 | 1.1 | 1.1 | 85.4 | 72.6 | 72.4 | 97.5 |
CharacterTrajectories | 96.9 | 99.0 | 99.0 | 97.1 | 5.4 | 6.7 | 93.8 | 96.0 | 92.0 | 97.0 |
CMUsubject16 | 60.0 | 100 | 99.7 | 98.3 | 53.1 | 51.0 | 51.4 | 97.6 | 89.3 | 100 |
ECG | 74.8 | 87.2 | 86.7 | 87.2 | 67.0 | 67.0 | 50.0 | 84.1 | 73.7 | 91.0 |
JapaneseVowels | 97.6 | 99.3 | 99.2 | 97.6 | 9.2 | 23.8 | 94.4 | 95.6 | 96.5 | |
Libras | 78.0 | 96.4 | 95.4 | 78.3 | 6.7 | 6.7 | 65.1 | 63.7 | 79.4 | 88.9 |
UWave | 90.1 | 93.4 | 92.6 | 90.8 | 12.5 | 12.5 | 84.5 | 8.9 | 75.4 | 91.0 |
KickvsPunch | 61.0 | 54.0 | 51.0 | 61.0 | 54.0 | 50.0 | 56.0 | 62.0 | 67.0 | 90.0 |
NetFlow | 55.0 | 89.1 | 62.7 | 77.7 | 77.9 | 72.3 | 63.0 | 89.0 | 94.5 | 100 |
PEMS | - | - | - | - | - | - | - | - | - | 93.6 |
Wafer | 89.4 | 98.2 | 98.9 | 98.6 | 89.4 | 89.4 | 65.8 | 94.8 | 94.9 | 99.1 |
WalkvsRun | 70.0 | 100 | 100 | 100 | 75.0 | 60.0 | 45.0 | 100.0 | 94.4 | 100 |
环境 | 描述 |
---|---|
语言 | Python3.7 |
框架 | Pytorch1.6 |
IDE | Pycharm and Colab |
设备 | CPU and GPU |
多元时间序列数据集, 文件为.mat格式,训练集与测试集在一个文件中,且预先定义为了测试集数据,测试集标签,训练集数据与训练集标签。
数据集下载使用百度云盘,连接如下:
链接:https://pan.baidu.com/s/1u2HN6tfygcQvzuEK5XBa2A
提取码:dxq6
数据集维度描述 DataSet|Number of Classes|Size of training Set|Size of testing Set|Max Time series Length|Channel| -------|-----------------|--------------------|-------------------|----------------------|-------| ArabicDigits|10|6600|2200|93|13| AUSLAN|95|1140|1425|136|22| CharacterTrajectories|20|300|2558|205|3| CMUsubject16|2|29|29|580|62| ECG|2|100|100|152|2| JapaneseVowels|9|270|370|29|12| Libras|15|180|180|45|2| UWave|8|200|4278|315|3| KickvsPunch|2|16|10|841|62| NetFlow|2|803|534|997|4| PEMS|7|267|173|144|963| Wafer|2|298|896|198|6| WalkvsRun|2|28|16|1918|62|
详细数据集处理过程参看 dataset_process.py文件。
h = W · Concat(C, S) + b
g1, g2 = Softmax(h)
y = Concat(C · g1, S · g2)
超参 | 描述 |
---|---|
d_model | 模型处理的为时间序列而非自然语言,所以省略了NLP中对词语的编码,仅使用一个线性层映射成d_model维的稠密向量,此外,d_model保证了在每个模块衔接的地方的维度相同 |
d_hidden | Position-wise FeedForword 中隐藏层的维度 |
d_input | 时间序列长度,其实是一个数据集中最长时间步的维度 固定的,直接由数据集预处理决定 |
d_channel | 多元时间序列的时间通道数,即是几维的时间序列 固定的,直接由数据集预处理决定 |
d_output | 分类类别数 固定的,直接由数据集预处理决定 |
q,v | Multi-Head Attention中线性层映射维度 |
h | Multi-Head Attention中头的数量 |
N | Encoder栈中Encoder的数量 |
dropout | 随机失活 |
EPOCH | 训练迭代次数 |
BATCH_SIZE | mini-batch size |
LR | 学习率 定义为1e-4 |
optimizer_name | 优化器选择 建议Adagrad和Adam |
文件名称 | 描述 |
---|---|
dataset_process | 数据集处理 |
font | 存储字体,用于结果图中的文字 |
gather_figure | 聚类结果图 |
heatmap_figure_in_test | 测试模型时绘制的score矩阵的热力图 |
module | 模型的各个模块 |
mytest | 各种测试代码 |
reslut_figure | 准确率结果图 |
saved_model | 保存的pkl文件 |
utils | 工具类文件 |
run.py | 训练模型 |
run_with_saved_model.py | 使用训练好的模型(保存为pkl文件)测试结果 |
简单介绍几个
[Wang et al., 2017] Z. Wang, W. Yan, and T. Oates. Time series classification from scratch with deep neural networks:A strong baseline. In 2017 International Joint Conference on Neural Networks (IJCNN), pages 1578–1585, 2017.