当前位置:网站首页>Use the kaggle training model and download your own training model
Train a model on Kaggle and download your trained model
2022-07-02 08:19:00 【Fuly1024】
Kaggle address: https://www.kaggle.com/
Upload a dataset

Or add a dataset that someone else has uploaded

View the data paths
import os
# Print the full path of every file mounted under /kaggle/input, walking the
# directory tree top-down in the order os.walk reports it.
for root, _subdirs, names in os.walk('/kaggle/input'):
    paths = (os.path.join(root, name) for name in names)
    for path in paths:
        print(path)

- Add the training code
# -*- coding: utf-8 -*-
import datetime
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dropout, Dense, SimpleRNN
import matplotlib.pyplot as plt
import os
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
import math
# Module-level min-max scaler shared by the whole script: get_stock_data()
# fits it on the training prices, and stock_predict() reuses the same fitted
# instance to inverse-transform predictions back into price units.
sc = MinMaxScaler(feature_range=(0, 1))  # scale values into the [0, 1] range
def _make_windows(series, window):
    """Turn a scaled ``(n, 1)`` column into supervised sliding-window samples.

    Sample ``i`` uses ``series[i - window:i, 0]`` as its input sequence and
    ``series[i, 0]`` (the value right after the window) as its target.

    Returns:
        ``(x, y)`` with ``x`` of shape ``(n - window, window, 1)`` and ``y`` of
        shape ``(n - window,)``; both are empty when ``n <= window``.
    """
    targets = range(window, len(series))
    x = np.array([series[i - window:i, 0] for i in targets])
    x = np.reshape(x, (len(targets), window, 1))
    y = np.array([series[i, 0] for i in targets])
    return x, y


def get_stock_data(file_path, test_size=300, window=60):
    """Load a price CSV and build shuffled training / chronological test windows.

    The last ``test_size`` rows form the test split; all earlier rows form the
    training split. Only the third column (positional index 2) is used,
    presumably the opening price in SH600519.csv; verify against the file.

    Side effect: fits the module-level ``sc`` scaler on the training split.
    ``stock_predict`` later relies on that fitted scaler to undo the scaling.

    Args:
        file_path: path of the CSV file to read.
        test_size: number of trailing rows held out for testing. For the
            2426-row SH600519.csv the default reproduces the original
            hard-coded split at row 2126.
        window: number of past time steps fed to the model per sample.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``. The ``x`` arrays have shape
        ``(samples, window, 1)`` and the ``y`` arrays have shape ``(samples,)``.

    Raises:
        ValueError: if ``test_size`` does not leave rows in both splits.
    """
    maotai = pd.read_csv(file_path)
    if not 0 < test_size < len(maotai):
        raise ValueError(
            'test_size=%d must be between 1 and %d for a file with %d rows'
            % (test_size, len(maotai) - 1, len(maotai)))
    # Derive the split from the real row count instead of assuming the file
    # has exactly 2426 rows.
    split = len(maotai) - test_size
    training_set = maotai.iloc[:split, 2:3].values
    test_set = maotai.iloc[split:, 2:3].values
    # Fit on the training data only, so test statistics do not leak into scaling.
    training_set_scaled = sc.fit_transform(training_set)
    test_set_scaled = sc.transform(test_set)

    x_train, y_train = _make_windows(training_set_scaled, window)
    # Shuffle inputs and targets with ONE permutation so each pair stays aligned.
    # A private seeded generator avoids reseeding numpy's global RNG.
    order = np.random.RandomState(7).permutation(len(x_train))
    x_train, y_train = x_train[order], y_train[order]

    # Test windows keep their chronological order so predictions can be plotted.
    x_test, y_test = _make_windows(test_set_scaled, window)
    return (x_train, y_train), (x_test, y_test)
def load_local_model(model_path):
    """Return the SavedModel stored in ``model_path``, or build a new RNN.

    If ``<model_path>/saved_model.pb`` exists, the current timestamp is
    printed and the model is loaded with ``tf.keras.models.load_model``.
    Otherwise a fresh two-layer SimpleRNN regressor (80 then 100 units, each
    followed by 20% dropout, and a single-unit Dense output) is built. It is
    compiled with Adam (learning rate 0.001) and a mean-squared-error loss.
    """
    saved_graph = model_path + '/saved_model.pb'
    if not os.path.exists(saved_graph):
        layers = [
            SimpleRNN(80, return_sequences=True),
            Dropout(0.2),
            SimpleRNN(100),
            Dropout(0.2),
            Dense(1),
        ]
        fresh_model = tf.keras.Sequential(layers)
        fresh_model.compile(
            optimizer=tf.keras.optimizers.Adam(0.001),
            loss='mean_squared_error',  # regression target: mean squared error
        )
        return fresh_model
    print(datetime.datetime.now())
    return tf.keras.models.load_model(model_path)
def show_train_line(history):
    """Plot the per-epoch training and validation loss from a Keras History.

    Reads ``history.history['loss']`` and ``history.history['val_loss']``
    (both are fetched before anything is drawn), then shows the chart.
    """
    records = history.history
    curves = [
        (records['loss'], 'Training Loss'),
        (records['val_loss'], 'Validation Loss'),
    ]
    for values, caption in curves:
        plt.plot(values, label=caption)
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()
def stock_predict(model, x_test, y_test, save_path='./model/rnn/compare.jpg'):
    """Predict on the test windows, plot them against the real prices, report errors.

    Both predictions and targets are mapped back to price units with the
    module-level ``sc`` scaler, which must already have been fitted by
    ``get_stock_data``.

    Args:
        model: a fitted Keras model accepting ``x_test``.
        x_test: test input windows.
        y_test: scaled test targets, shape ``(samples,)``.
        save_path: where the comparison chart is written. Missing parent
            directories are created.

    Returns:
        ``(mse, rmse, mae)`` computed in original price units. The values are
        also printed.
    """
    # Run the model, then undo the (0, 1) normalization on predictions and targets.
    predicted_stock_price = sc.inverse_transform(model.predict(x_test))
    real_stock_price = sc.inverse_transform(np.reshape(y_test, (y_test.shape[0], 1)))

    plt.plot(real_stock_price, color='red', label='MaoTai Stock Price')
    plt.plot(predicted_stock_price, color='blue', label='Predicted MaoTai Stock Price')
    plt.title('MaoTai Stock Price Prediction')
    plt.xlabel('Time')
    plt.ylabel('MaoTai Stock Price')
    plt.legend()
    # Save BEFORE show(): show() closes the current figure in notebook/script
    # backends, so a savefig() afterwards writes a blank image.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path)
    plt.show()

    # sklearn's documented argument order is (y_true, y_pred).
    mse = mean_squared_error(real_stock_price, predicted_stock_price)
    rmse = math.sqrt(mse)
    mae = mean_absolute_error(real_stock_price, predicted_stock_price)
    print(' Mean square error : %.6f' % mse)
    print(' Root mean square error : %.6f' % rmse)
    print(' Mean absolute error : %.6f' % mae)
    return mse, rmse, mae
if __name__ == '__main__':
    # Input CSV (Kaggle dataset mount) and the directory for the SavedModel.
    file_path = '/kaggle/input/databases/SH600519.csv'
    model_path = "./model/rnn"

    (x_train, y_train), (x_test, y_test) = get_stock_data(file_path)
    # Resumes from a previous SavedModel if one exists; otherwise starts fresh.
    model = load_local_model(model_path)

    # NOTE(review): the test split doubles as validation data, so the
    # validation loss is not an independent held-out estimate.
    # NOTE(review): batch_size=265 may be a typo for 256 — confirm intent.
    history = model.fit(
        x_train,
        y_train,
        batch_size=265,
        epochs=100,
        validation_data=(x_test, y_test),
        validation_freq=1,
    )
    show_train_line(history)
    model.summary()
    model.save(model_path, save_format="tf")
    stock_predict(model, x_test, y_test)
Choose GPU or TPU as the accelerator

Save the model

Download the trained model
(1) Save a version (File → Save Version)
Be sure to choose "Save & Run All (Commit)"
Under Advanced Settings, choose "Always save output" or "Save output for this version"
(Select the accelerator here as well)
Click Save and wait for the run to finish
Find the version you just saved

Download the output files, and you are done.
边栏推荐
- 力扣方法总结:滑动窗口
- Intelligent manufacturing solutions digital twin smart factory
- One of the reasons for WCF update service reference error
- Jupyter Notebook常用快捷键(在命令模式中按H也可查看)
- Use of opencv3 6.2 low pass filter
- CarSim problem failed to start solver: path_ ID_ OBJ(X) was set to Y; no corresponding value of XXXXX?
- OpenCV常用方法出处链接(持续更新)
- Force buckle method summary: sliding window
- Global and Chinese market of snow sweepers 2022-2028: Research Report on technology, participants, trends, market size and share
- How to wrap qstring strings
猜你喜欢

MySQL优化

Development of digital collection trading website development of metauniverse digital collection

Dynamic extensible representation for category incremental learning -- der

包图画法注意规范

2022 Heilongjiang latest food safety administrator simulation exam questions and answers

Command line is too long

Principes fondamentaux de la théorie musicale (brève introduction)

Carsim problem failed to start Solver: Path Id Obj (X) was set to y; Aucune valeur de correction de xxxxx?

Carsim-問題Failed to start Solver: PATH_ID_OBJ(X) was set to Y; no corresponding value of XXXXX?

SQLyog远程连接centos7系统下的MySQL数据库
随机推荐
Global and Chinese market of electric cheese grinder 2022-2028: Research Report on technology, participants, trends, market size and share
OpenCV常用方法出处链接(持续更新)
Introduction to anti interception technology of wechat domain name
SQL operation database syntax
Summary of one question per day: stack and queue (continuously updated)
How to apply for a secondary domain name?
用C# 语言实现MYSQL 真分页
Valin cable: BI application promotes enterprise digital transformation
Command line is too long
Animation synchronization of CarSim real-time simulation
程序猿学英语-Learning C
力扣每日一题刷题总结:链表篇(持续更新)
最长等比子序列
Library function of C language
How to uninstall SQL Server cleanly
OpenCV3 6.2 低通滤波器的使用
OpenCV 6.4 中值滤波器的使用
The internal network of the server can be accessed, but the external network cannot be accessed
使用Matplotlib绘制图表初步
MySQL优化