当前位置:网站首页>用K-means聚类分类不同行业的关税模型
用K-means聚类分类不同行业的关税模型
2022-07-28 09:58:00 【interval_package】
这里对不同行业的关税数据做聚类分析。
之前已经对每个行业拟合过 VAR 模型,这里用 VAR 的预测误差方差分解(FEVD)结果作为聚类的特征。
跟大家分享一下。
部分代码借鉴了 sklearn 官网的示例(Comparison of the K-Means and MiniBatchKMeans clustering algorithms)。
有什么想交流的欢迎在评论区留言。
import matplotlib.pyplot as plt
from VARModelFitting import *
from sklearn.cluster import MiniBatchKMeans, KMeans
from sklearn.metrics.pairwise import pairwise_distances_argmin
import scipy.interpolate as spi
import time
class MyIndustry(object):
    """One industry's VAR model, reduced to its forecast-error variance decomposition (FEVD).

    The FEVD array obtained from ``fevdIndentity`` is kept in ``rawDecompInfo`` and
    reshaped on demand by :meth:`DataJointing` into the 2-D feature matrix used for
    K-means clustering.
    """

    def __init__(self, Data: pd.DataFrame, name, lag=3):
        """Fit the VAR model for one industry and keep its 10-period FEVD.

        :param Data: the industry's time series, passed unchanged to ``VARFitter``.
        :param name: industry name (used for labelling only).
        :param lag: VAR lag order.
        """
        self.name = name
        self.rawDecompInfo = self.fevdIndentity(Data, lag, 10)
        # Feature matrices are built lazily by DataJointing(); nothing is precomputed
        # here, so constructing an instance does no interpolation and prints nothing.

    @staticmethod
    def fevdIndentity(Data, lag, MaxPeriod):
        """Fit a VAR(lag) via the project's ``VARFitter`` and return ``fevd(MaxPeriod).decomp``.

        NOTE(review): element ``[1]`` of VARFitter's return value is presumably a
        statsmodels ``VARResults``, making ``decomp`` an array of shape
        (n_equations, MaxPeriod, n_equations) — confirm against VARModelFitting.
        """
        return VARFitter(Data, lag)[1].fevd(MaxPeriod).decomp

    def DataJointing(self, sepecificFevd=-1, enhanceFlag=True, enhanceCount=1000):
        """Build the 2-D feature matrix used for clustering.

        :param sepecificFevd: index into ``rawDecompInfo``. A negative value stacks
            every entry of ``rawDecompInfo`` row-wise instead of selecting one.
        :param enhanceFlag: if true, densify the rows with :meth:`DataEnhance`.
        :param enhanceCount: number of rows :meth:`DataEnhance` produces.
        :return: 2-D ``numpy.ndarray``.
        :raises ValueError: if all entries are requested but ``rawDecompInfo`` is empty.
        """
        if sepecificFevd < 0:
            blocks = [np.asarray(item) for item in self.rawDecompInfo]
            if not blocks:
                raise ValueError("rawDecompInfo is empty; nothing to join")
            output = np.concatenate(blocks, axis=0)
        else:
            output = self.rawDecompInfo[sepecificFevd]
        print("output:\n", output)
        if enhanceFlag:
            # Pass by keyword: DataEnhance's second positional parameter is _min, so a
            # positional enhanceCount would be swallowed there and silently ignored.
            output = self.DataEnhance(output, enhanceCount=enhanceCount)
            print("enhanceFlag output:\n", output)
        return output

    @staticmethod
    def DataEnhance(data: np.ndarray, _min=-1, _max=-1, enhanceCount=50):
        """Resample every column of *data* onto a denser grid with an interpolating spline.

        Row ``i`` of *data* is treated as the sample at position ``_min + i``. Each column
        is fitted with ``scipy.interpolate.splrep`` (cubic, interpolating) and evaluated at
        *enhanceCount* evenly spaced positions from the first to the last sample.

        :param data: 2-D array, rows = samples, columns = series (cubic splines need at
            least 4 rows).
        :param _min: first sample position; if ``_min`` or ``_max`` is negative, the
            positions default to ``0 .. data.shape[0] - 1``.
        :param _max: one past the last sample position (``_max - _min`` must equal
            ``data.shape[0]``).
        :param enhanceCount: number of output rows.
        :return: array of shape ``(enhanceCount, data.shape[1])``.
        """
        data = np.asarray(data)
        if _min < 0 or _max < 0:
            _min, _max = 0, data.shape[0]
        periods = np.arange(_min, _max)
        # Stop at the last real sample (_max - 1); ending at _max would make splev
        # extrapolate beyond the data for the final point.
        amplifiedPeriods = np.linspace(_min, _max - 1, enhanceCount)
        columns = [
            spi.splev(amplifiedPeriods, spi.splrep(periods, column), der=0)
            for column in data.T
        ]
        if not columns:
            return np.empty((amplifiedPeriods.shape[0], 0))
        return np.column_stack(columns)

    @staticmethod
    def _countLabels(labels, n_clusters):
        """Return a float array of shape (n_clusters, 1) counting each label 0..n_clusters-1."""
        counts = np.bincount(np.asarray(labels, dtype=np.intp), minlength=n_clusters)[:n_clusters]
        return counts.astype(float).reshape(n_clusters, 1)

    def industryIdentityDefine(self, k_means_cluster_centers, mbk_means_cluster_centers, n_clusters, sepecificFevd=0):
        """Count how many of this industry's FEVD rows fall into each cluster.

        The rows come from ``DataJointing(sepecificFevd, enhanceFlag=False)`` (no spline
        enhancement). Each row is assigned to its nearest center (Euclidean distance, the
        ``pairwise_distances_argmin`` default).

        :return: ``(k_means_result, mbk_means_result)``, each a float array of shape
            ``(n_clusters, 1)`` holding the per-cluster row counts.
        """
        inputs = self.DataJointing(sepecificFevd=sepecificFevd, enhanceFlag=False)
        k_means_labels = pairwise_distances_argmin(inputs, k_means_cluster_centers)
        mbk_means_labels = pairwise_distances_argmin(inputs, mbk_means_cluster_centers)
        return (self._countLabels(k_means_labels, n_clusters),
                self._countLabels(mbk_means_labels, n_clusters))
def industryIdentityShow(k_means_cluster_centers, mbk_means_cluster_centers, n_clusters, sepecificFevd=0):
    """Plot, for every industry, how its FEVD rows distribute over the K-means clusters.

    Industries come from ``GetClasses()``. Each one is loaded with ``ReadTariffData`` /
    ``ProcessBaseData``, fitted through ``MyIndustry``, and drawn as a bar chart of its
    per-cluster row counts against the K-means centers. The MiniBatchKMeans counts are
    computed by ``industryIdentityDefine`` but not plotted. Industries whose fit fails
    are reported on stdout and leave their panel empty.
    """
    plt.figure("industryIdentityShow", figsize=(20, 16))
    Data_base, classNames = GetClasses()
    # Keep the 5x5 grid, but add rows when there are more than 25 industries;
    # plt.subplot raises ValueError once the index exceeds rows * cols.
    ncols = 5
    nrows = max(5, (len(classNames) + ncols - 1) // ncols)
    for idx, name in enumerate(classNames):
        Data = ProcessBaseData(ReadTariffData(Data_base, name))
        name = name.strip()
        try:
            k_means_result, _ = MyIndustry(Data, name).industryIdentityDefine(
                k_means_cluster_centers, mbk_means_cluster_centers, n_clusters, sepecificFevd)
        except np.linalg.LinAlgError as e:
            print(name + ": fail the var model LinAlgError: ", e)
            print(repr(e))
            continue
        except Exception as e:
            # Same policy as unpackData(): a failing industry is reported and skipped,
            # not allowed to abort the whole figure.
            print(name + ": fail the var model, unknown: ", e)
            print(repr(e))
            continue
        ax = plt.subplot(nrows, ncols, idx + 1)
        ax.set_title(name)
        ax.bar(np.arange(k_means_result.shape[0]), k_means_result[:, 0])
    plt.show()
def clusteringPeriod(inputs, batch_size=45, n_clusters=5, random_state=None):
    """Fit KMeans and MiniBatchKMeans on *inputs* and time each fit.

    :param inputs: 2-D feature matrix, one sample per row.
    :param batch_size: mini-batch size for MiniBatchKMeans.
    :param n_clusters: number of clusters for both models.
    :param random_state: forwarded to both estimators. Pass an int for reproducible
        centers; ``None`` keeps the unseeded behaviour.
    :return: ``(k_means, t_batch, mbk, t_mini_batch)``, i.e. the fitted models and
        their fit times in seconds.
    """
    # ---- KMeans -------------------------------------------------------------
    k_means = KMeans(init="k-means++", n_clusters=n_clusters, n_init=10, random_state=random_state)
    # perf_counter is monotonic and high-resolution; time.time() can jump.
    t0 = time.perf_counter()
    k_means.fit(inputs)
    t_batch = time.perf_counter() - t0
    print('cluster_centers of k_means:\n', k_means.cluster_centers_)

    # ---- MiniBatchKMeans ----------------------------------------------------
    mbk = MiniBatchKMeans(
        init="k-means++",
        n_clusters=n_clusters,
        batch_size=batch_size,
        n_init=10,
        max_no_improvement=10,
        verbose=0,
        random_state=random_state,
    )
    t0 = time.perf_counter()
    mbk.fit(inputs)
    t_mini_batch = time.perf_counter() - t0
    print('cluster_centers of MiniBatchKMeans:\n', mbk.cluster_centers_)
    return k_means, t_batch, mbk, t_mini_batch
def _plot_cluster_panel(ax, inputs, labels, centers, colors, title, elapsed, inertia):
    """Scatter the first two columns of *inputs* coloured by *labels*, and mark each center."""
    for k, col in enumerate(colors):
        members = labels == k
        ax.plot(inputs[members, 0], inputs[members, 1], "w", markerfacecolor=col, marker=".")
        ax.plot(
            centers[k][0],
            centers[k][1],
            "o",
            markerfacecolor=col,
            markeredgecolor="k",
            markersize=6,
        )
    ax.set_title(title)
    ax.set_xticks(())
    ax.set_yticks(())
    # Axes-relative coordinates keep the text inside the panel whatever the data
    # range; fixed data coordinates such as (-3.5, 1.8) can land far outside it.
    ax.text(0.02, 0.98, "train time: %.2fs\ninertia: %f" % (elapsed, inertia),
            transform=ax.transAxes, va="top")


def clusteringComparingPloting(inputs, batch_size=45, n_clusters=5):
    """Fit KMeans and MiniBatchKMeans on *inputs* and plot them side by side.

    Three panels are drawn: KMeans, MiniBatchKMeans, and the points on which the two
    disagree. Only the first two feature columns are plotted, so *inputs* must be a
    2-D array with at least two columns. The MiniBatchKMeans centers are reordered so
    that cluster ``k`` refers to matched centers in both models.

    :return: ``(k_means_cluster_centers, mbk_means_cluster_centers, n_clusters)``,
        with the MiniBatchKMeans centers already reordered to match KMeans.
    """
    from scipy.optimize import linear_sum_assignment
    from scipy.spatial.distance import cdist

    k_means, t_batch, mbk, t_mini_batch = clusteringPeriod(inputs, batch_size, n_clusters)
    n_clusters = k_means.n_clusters

    fig = plt.figure(figsize=(8, 3))
    fig.subplots_adjust(left=0.02, right=0.98, bottom=0.05, top=0.9)
    # The first three colours come from the sklearn example. Further clusters use
    # matplotlib's default cycle ("C3", "C4", ...) so that no cluster is dropped.
    base_colors = ["#4EACC5", "#FF9C34", "#4E9A06"]
    colors = [base_colors[k] if k < len(base_colors) else "C%d" % (k % 10)
              for k in range(n_clusters)]

    # Give each KMeans center a distinct MiniBatchKMeans partner by minimising the
    # total distance. A per-row argmin can send two KMeans centers to the same
    # MiniBatchKMeans center and leave another one unused.
    k_means_cluster_centers = k_means.cluster_centers_
    _, order = linear_sum_assignment(cdist(k_means_cluster_centers, mbk.cluster_centers_))
    mbk_means_cluster_centers = mbk.cluster_centers_[order]

    # Assign every sample to its nearest (aligned) center.
    k_means_labels = pairwise_distances_argmin(inputs, k_means_cluster_centers)
    mbk_means_labels = pairwise_distances_argmin(inputs, mbk_means_cluster_centers)

    _plot_cluster_panel(fig.add_subplot(1, 3, 1), inputs, k_means_labels,
                        k_means_cluster_centers, colors, "KMeans", t_batch, k_means.inertia_)
    _plot_cluster_panel(fig.add_subplot(1, 3, 2), inputs, mbk_means_labels,
                        mbk_means_cluster_centers, colors, "MiniBatchKMeans",
                        t_mini_batch, mbk.inertia_)

    # The labels index the same aligned centers, so a point "differs" exactly when
    # its two labels differ (valid for any number of clusters).
    different = k_means_labels != mbk_means_labels
    identic = np.logical_not(different)
    ax = fig.add_subplot(1, 3, 3)
    ax.plot(inputs[identic, 0], inputs[identic, 1], "w", markerfacecolor="#bbbbbb", marker=".")
    ax.plot(inputs[different, 0], inputs[different, 1], "w", markerfacecolor="m", marker=".")
    ax.set_title("Difference")
    ax.set_xticks(())
    ax.set_yticks(())
    plt.show()
    return k_means_cluster_centers, mbk_means_cluster_centers, n_clusters
def unpackData(unpackSepecificItem=-1, enhanceFlag=True):
    """Build the clustering training matrix; every row is one sample to learn from.

    Each industry from ``GetClasses()`` is loaded, fitted through ``MyIndustry``, and
    converted to a 2-D block with ``DataJointing(unpackSepecificItem, enhanceFlag)``.
    The blocks are stacked row-wise. An industry whose fit fails, or whose block's
    column layout does not match the first accepted block, is reported on stdout
    and skipped.

    :raises ValueError: if no industry produced any data.
    """
    Data_base, classNames = GetClasses()
    parts = []
    for name in classNames:
        Data = ProcessBaseData(ReadTariffData(Data_base, name))
        name = name.strip()
        try:
            part = np.asarray(MyIndustry(Data, name).DataJointing(unpackSepecificItem, enhanceFlag))
            # Check compatibility per industry so that one mismatching block is
            # skipped here instead of breaking the single concatenate at the end.
            if parts and part.shape[1:] != parts[0].shape[1:]:
                raise ValueError("feature shape %s does not match %s"
                                 % (part.shape[1:], parts[0].shape[1:]))
        except TypeError as e:
            print(name + ": fail the var model, TypeError")
            print(repr(e))
        except np.linalg.LinAlgError as e:
            print(name + ": fail the var model LinAlgError: ", e)
            print(repr(e))
        except Exception as e:
            print(name + ": fail the var model, unknown: ", e)
            print(repr(e))
        else:
            parts.append(part)
    if not parts:
        raise ValueError("unpackData: no industry produced FEVD data")
    # A single concatenate avoids re-copying the growing matrix on every industry.
    output = np.concatenate(parts, axis=0)
    print(output.shape)
    return output
def main():
    """Cluster industries by their FEVD features, then plot each industry's cluster profile."""
    # Training set: FEVD entry 1 of every industry, spline-enhanced.
    training = unpackData(unpackSepecificItem=1, enhanceFlag=True)
    km_centers, mbk_centers, cluster_count = clusteringComparingPloting(training, n_clusters=5)
    # NOTE(review): training uses FEVD entry 1 but the per-industry profile below uses
    # entry 0. Confirm that this mismatch is intentional.
    industryIdentityShow(km_centers, mbk_centers, cluster_count, sepecificFevd=0)


if __name__ == '__main__':
    main()
边栏推荐
- In the era of home health diagnosis, Senzo creates enhanced lateral flow test products
- Joint search set
- [openharmony] [rk2206] build openharmony compiler (2)
- In retaliation for the dismissal of the company, I changed all code comments of the project!
- 排序——快速排序(快慢指针实现)
- MySQL架构原理
- On July 13, 2021, we collapsed like this
- 并查集
- Introduction to thresholdfilter
- ES (8.1) certification topic
猜你喜欢
随机推荐
博弈论 1.Introduction(组合游戏基本概念、对抗搜索、Bash游戏、Nim游戏)
Edge team explains how to improve the comprehensive performance experience through disk cache compression technology
16、字符串反转
【JZOF】15二进制中1的位数
头文件库文件
LinkedList源码按摩,啊舒服
centos7下安装mysql,网上文章都不太准
SuperMap iServer发布管理以及调用地图服务
Being on duty less than 8 hours a day and being dismissed? Tencent's former employees recovered 13million overtime pay, etc., and the court won a compensation of 90000 in the final judgment
Detailed explanation of thread synchronization volatile and synchronized
房地产数字化转型方案:全方位数智化系统运营,助力房企管控实效提升
海量数据TopN问题
二维前缀和
Voice chat app - how to standardize the development process?
Prometheus operation and maintenance tool promtool (IV) TSDB function
LSA and optimization of OSPF
[esp32][esp idf] ap+sta realizes wireless bridging and transferring WiFi signals
Uni app advanced creation component / native rendering
记录一次idea中的父子项目修改project与module名称,亲测!
Boss: there are too many systems in the company. Can we realize account interworking?


![[openharmony] [rk2206] build openharmony compiler (2)](/img/0c/2e8290403d64ec43d192969f776724.png)






