当前位置:网站首页>小黑ai4code代码baseline啃食1
小黑ai4code代码baseline啃食1
2022-06-24 07:01:00 【小黑无敌】
比赛链接:https://www.kaggle.com/competitions/AI4Code
源代码链接:https://www.kaggle.com/code/ryanholbrook/getting-started-with-ai4code/notebook
导包
import json
from pathlib import Path
import numpy as np
import pandas as pd
from scipy import sparse
from tqdm import tqdm
# Widen pandas' console/notebook rendering so long cell sources stay readable.
pd.set_option('display.width', 180)
pd.set_option('display.max_colwidth', 120)
# Root directory of the competition data.
data_dir = Path('../input') / 'AI4Code'
训练cell数据读入
# Maximum number of training notebooks to load.
NUM_TRAIN = 10000
# Sort before slicing: Path.glob yields files in an arbitrary, filesystem-
# dependent order, so an unsorted slice would select a different subset of
# notebooks on different machines or runs. Sorting makes the subset reproducible.
paths_train = sorted((data_dir / 'train').glob('*.json'))[:NUM_TRAIN]
def read_notebook(path):
    """Load one notebook JSON file into a DataFrame.

    Parameters
    ----------
    path : pathlib.Path
        Path to the notebook's ``.json`` file. Its stem (file name without
        extension) is used as the notebook id.

    Returns
    -------
    pandas.DataFrame
        Columns ``cell_type`` (categorical), ``source`` (string) and ``id``
        (the file stem, repeated on every row). The row index is named
        ``cell_id``. Presumably the JSON maps each column name to a
        ``{cell_id: value}`` mapping -- confirm against the data files.
    """
    column_dtypes = {'cell_type': 'category', 'source': 'str'}
    frame = pd.read_json(path, dtype=column_dtypes)
    # Tag every row with the notebook id so notebooks can be concatenated.
    frame = frame.assign(id=path.stem)
    return frame.rename_axis('cell_id')
# Parse every selected training notebook; tqdm wraps the iterable to show progress.
notebooks_train = list(map(read_notebook, tqdm(paths_train, desc='Train NBs')))
建立索引(id,cell_id)
# Stack all training notebooks into one frame indexed by (id, cell_id):
#   * set_index('id', append=True) adds the notebook id as a second index level,
#   * swaplevel() moves `id` to the outer level,
#   * sort_index(level='id', sort_remaining=False) sorts by notebook id only,
#     without additionally sorting by the remaining level (cell_id).
df = pd.concat(notebooks_train)
df = df.set_index('id', append=True)
df = df.swaplevel()
df = df.sort_index(level='id', sort_remaining=False)
df

# Inspect one example notebook: the distinct notebook id at position 6.
notebook_ids = df.index.unique('id')
nb_id = notebook_ids[6]
print('Notebook:', nb_id)
print("The disordered notebook:")
# Selecting a single outer-level label leaves a frame indexed by cell_id.
nb = df.loc[nb_id, :]
# NOTE(review): `display` is not imported here; presumably the IPython
# notebook helper available in Jupyter/Kaggle kernels.
display(nb)
print()

获取训练集的cell序列数据
# Ground-truth cell order for every training notebook. train_orders.csv holds,
# per notebook id, the cell ids as one space-separated string. The file is read
# once, squeezed from a single-column frame into a Series, and each string is
# split into a list of cell ids.
# `read_csv(squeeze=True)` (used previously, and read twice) was deprecated in
# pandas 1.4 and removed in pandas 2.0; `.squeeze('columns')` is the replacement.
df_orders = (
    pd.read_csv(data_dir / 'train_orders.csv', index_col='id')
    .squeeze('columns')
    .str.split()
)
df_orders
# Ground-truth (correct) cell order of the example notebook.
cell_order = df_orders.loc[nb_id]
print('The ordered notebook:')
# Show the example notebook's rows rearranged into that order.
nb.loc[cell_order]

def get_ranks(base, derived):
    """Return the position in ``base`` of every item of ``derived``.

    Parameters
    ----------
    base : iterable
        Reference ordering, e.g. a notebook's ground-truth cell ids. Items
        must be hashable (cell ids are strings). If an item occurs more than
        once, its first position is used, as with ``list.index``.
    derived : iterable
        Items to look up, e.g. the cell ids in the order they were read.

    Returns
    -------
    list[int]
        ``[position of d in base for d in derived]``.

    Raises
    ------
    ValueError
        If an item of ``derived`` does not occur in ``base``.
    """
    # Build the lookup once instead of calling base.index per item, which made
    # the original O(len(base) * len(derived)).
    position = {}
    for i, item in enumerate(base):
        position.setdefault(item, i)  # keep the first occurrence
    try:
        return [position[d] for d in derived]
    except KeyError as err:
        # Preserve the original exception type (list.index raises ValueError).
        raise ValueError(f'{err.args[0]!r} is not in base') from None
# Ground-truth position of each cell, taken in the order the cells are stored in `nb`.
cell_ranks = get_ranks(cell_order, nb.index.tolist())
# Store it as the leading 'rank' column; sorting by it restores the true order.
nb.insert(loc=0, column='rank', value=cell_ranks)
nb

from pandas.testing import assert_frame_equal
# Sanity check: sorting by the computed rank must reproduce exactly the frame
# obtained by indexing with the ground-truth cell order.
expected = nb.loc[cell_order, :]
by_rank = nb.sort_values('rank')
assert_frame_equal(expected, by_rank)
# Pair each notebook's ground-truth order (from df_orders) with its cell ids
# in the order they appear in `df`; only notebooks present in `df` are kept
# (how='right').
df_orders_ = df_orders.to_frame().join(
    df.reset_index('cell_id').groupby('id')['cell_id'].apply(list),
    how='right',
)
# notebook id -> {'cell_id': [...], 'rank': [...]}.
# The loop names are chosen so they do not overwrite the global `cell_order`
# (the example notebook's order), which the original loop clobbered.
ranks = {
    nb_key: {'cell_id': nb_cells, 'rank': get_ranks(nb_order, nb_cells)}
    for nb_key, nb_order, nb_cells in df_orders_.itertuples()
}
# One row per (id, cell_id) with its ground-truth rank. explode() leaves the
# rank column as object dtype, so cast it to int64 before it becomes labels.
df_ranks = (
    pd.DataFrame.from_dict(ranks, orient='index')
    .rename_axis('id')
    .apply(pd.Series.explode)
    .set_index('cell_id', append=True)
    .astype({'rank': 'int64'})
)
df_ranks

# Per-notebook ancestry table; its 'ancestor_id' column serves below as the
# grouping key for the train/validation split.
df_ancestors = pd.read_csv(data_dir.joinpath('train_ancestors.csv'), index_col='id')
df_ancestors

数据集划分
from sklearn.model_selection import GroupShuffleSplit
# Hold out 10% of the ancestor groups for validation. Notebooks that share an
# ancestor_id always land on the same side of the split (presumably so that
# related/forked notebooks cannot leak between train and validation).
NVALID = 0.1
splitter = GroupShuffleSplit(n_splits=1, test_size=NVALID, random_state=0)
ids = df.index.unique('id')
ancestors = df_ancestors.loc[ids, 'ancestor_id']
# split() yields positional indices into `ids`; map them back to notebook ids.
train_pos, valid_pos = next(splitter.split(ids, groups=ancestors))
ids_train = ids[train_pos]
ids_valid = ids[valid_pos]
df_train = df.loc[ids_train, :]
df_valid = df.loc[ids_valid, :]
训练集特征构建
from sklearn.feature_extraction.text import TfidfVectorizer
# --- Training features ---
# TF-IDF over each cell's raw source; terms present in fewer than 1% of the
# cells are ignored (min_df as a fraction).
tfidf = TfidfVectorizer(min_df=0.01)
X_train = tfidf.fit_transform(df_train['source'].astype(str))
# Labels: ground-truth rank of every cell. Groups: number of cells per
# notebook, passed to the ranker to delimit each notebook's rows.
train_ranks = df_ranks.loc[ids_train]
y_train = train_ranks.to_numpy()
groups = train_ranks.groupby('id').size().to_numpy()
# Extra column: 1-based position of each code cell among its notebook's code
# cells (in df_train's row order); non-code cells get 0.
is_code = df_train['cell_type'] == 'code'
code_position = df_train.groupby(['id', 'cell_type']).cumcount().to_numpy() + 1
X_train = sparse.hstack((X_train, np.where(is_code, code_position, 0).reshape(-1, 1)))
print(X_train.shape)
# 输出: (416586, 284)
模型训练
from xgboost import XGBRanker
# Learning-to-rank model; each notebook is one group (sizes in `groups`).
ranker_params = dict(min_child_weight=10, subsample=0.5, tree_method='hist')
model = XGBRanker(**ranker_params)
model.fit(X_train, y_train, group=groups)

模型预测
# --- Validation features: same transformation as for training ---
X_valid = tfidf.transform(df_valid['source'].astype(str))
y_valid = df_orders.loc[ids_valid]
valid_code_position = np.where(
    df_valid['cell_type'] == 'code',
    df_valid.groupby(['id', 'cell_type']).cumcount().to_numpy() + 1,
    0,
).reshape(-1, 1)
X_valid = sparse.hstack((X_valid, valid_code_position))
# One predicted score per cell; sorting each notebook's cells by ascending
# score gives its predicted order (a list of cell ids per notebook id).
scores = pd.DataFrame({'rank': model.predict(X_valid)}, index=df_valid.index)
y_pred = (
    scores.sort_values(['id', 'rank'])
    .reset_index('cell_id')
    .groupby('id')['cell_id']
    .apply(list)
)
模型评估

from bisect import bisect
def count_inversions(a):
    """Count pairs (i, j) with i < j and a[i] > a[j].

    Equal values are not counted as an inversion.

    Parameters
    ----------
    a : iterable of comparable values

    Returns
    -------
    int
        Number of inverted pairs.
    """
    seen = []  # sorted copy of the prefix processed so far
    total = 0
    for value in a:
        # bisect gives how many seen values are <= value; every other seen
        # value is strictly greater and precedes it, i.e. is inverted with it.
        cut = bisect(seen, value)
        total += len(seen) - cut
        seen.insert(cut, value)
    return total
def kendall_tau(ground_truth, predictions):
    """Kendall-tau-style ordering score pooled over many notebooks.

    Computes ``1 - 4 * total_inversions / sum(n * (n - 1))``. Inversions are
    counted on the ground-truth positions of each predicted order, and ``n``
    is the length of each ground-truth order.

    Parameters
    ----------
    ground_truth : iterable of sequences
        Correct cell order per notebook.
    predictions : iterable of sequences
        Predicted cell order per notebook, paired with ``ground_truth`` *by
        position* -- the caller must align both in the same notebook order.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If the two inputs have different lengths (zip would silently drop
        notebooks), if a predicted id is missing from its ground truth, or if
        there are no cell pairs at all (the score is undefined).
    """
    ground_truth = list(ground_truth)
    predictions = list(predictions)
    if len(ground_truth) != len(predictions):
        raise ValueError(
            f'got {len(ground_truth)} ground-truth orders but '
            f'{len(predictions)} predictions'
        )
    total_inversions = 0
    total_2max = 0
    for gt, pred in zip(ground_truth, predictions):
        # Position lookup built once per notebook (first occurrence wins,
        # like list.index) instead of a linear gt.index call per cell.
        position = {}
        for i, cell in enumerate(gt):
            position.setdefault(cell, i)
        try:
            ranks = [position[x] for x in pred]
        except KeyError as err:
            raise ValueError(f'{err.args[0]!r} is not in the ground-truth order') from None
        total_inversions += count_inversions(ranks)
        n = len(gt)
        total_2max += n * (n - 1)
    if total_2max == 0:
        raise ValueError('kendall_tau needs at least one notebook with two or more cells')
    return 1 - 4 * total_inversions / total_2max
# Baseline: each validation notebook's cells kept in df_valid's row order.
y_dummy = df_valid.reset_index('cell_id').groupby('id')['cell_id'].apply(list)
# kendall_tau pairs ground truth and predictions by position, so align both
# prediction Series to y_valid's notebook order explicitly rather than relying
# on all three happening to be sorted by id. `.loc` raises KeyError if a
# validation notebook is missing from the predictions.
print('(y_valid, y_dummy):', kendall_tau(y_valid, y_dummy.loc[y_valid.index]))
print('(y_valid, y_pred):', kendall_tau(y_valid, y_pred.loc[y_valid.index]))
# 输出:
# (y_valid, y_dummy): 0.42511216883092573
# (y_valid, y_pred): 0.6158894721044015
提交
# Load every test notebook the same way as the training notebooks.
paths_test = list((data_dir / 'test').glob('*.json'))
notebooks_test = [read_notebook(p) for p in tqdm(paths_test, desc='Test NBs')]
# Same (id, cell_id) indexing as the training frame, sorted by notebook id only.
df_test = pd.concat(notebooks_test)
df_test = df_test.set_index('id', append=True).swaplevel()
df_test = df_test.sort_index(level='id', sort_remaining=False)
# Test features: identical construction to the training/validation features.
X_test = tfidf.transform(df_test['source'].astype(str))
test_code_position = np.where(
    df_test['cell_type'] == 'code',
    df_test.groupby(['id', 'cell_type']).cumcount().to_numpy() + 1,
    0,
).reshape(-1, 1)
X_test = sparse.hstack((X_test, test_code_position))
# Order each test notebook's cells by ascending predicted score.
test_scores = pd.DataFrame({'rank': model.predict(X_test)}, index=df_test.index)
y_infer = (
    test_scores
    .sort_values(['id', 'rank'])
    .reset_index('cell_id')
    .groupby('id')['cell_id']
    .apply(list)
)
# Sample submission (loaded but not otherwise used in this section).
# `read_csv(squeeze=True)` was removed in pandas 2.0 and would raise TypeError
# there; `.squeeze('columns')` turns the single-column frame into a Series.
y_sample = pd.read_csv(data_dir / 'sample_submission.csv', index_col='id').squeeze('columns')
# Join each notebook's predicted cell ids into one space-separated string
# (the format split apart when reading train_orders.csv) and write the file.
y_submit = (
    y_infer
    .apply(' '.join)
    .rename_axis('id')
    .rename('cell_order')
)
y_submit.to_csv('submission.csv')
y_submit.to_frame()
边栏推荐
猜你喜欢

2022年制冷与空调设备运行操作上岗证题库及模拟考试

ZUCC_编译语言原理与编译_实验04 语言与文法

12-- merge two ordered linked lists

Understanding of the concept of "quality"

【无标题】

2022 mobile crane driver special operation certificate examination question bank and online simulation examination

ZUCC_编译语言原理与编译_实验06 07 语法分析 LL 分析

Maya re deployment

jwt(json web token)

LabVIEW查找n个元素数组中的质数
随机推荐
Several ways you can't move zero (sequel)
Cloudbase database migration scheme
Industrial computer anti cracking
问题4 — DatePicker日期选择器,2个日期选择器(开始、结束日期)的禁用
ZUCC_编译语言原理与编译_实验02 FSharp OCaml语言
[real estate opening online house selection, WiFi coverage temporary network] 500 people are connected to WiFi at the same time
[graduation season] Hello stranger, this is a pink letter
复习SGI STL二级空间配置器(内存池) | 笔记自用
Common misconceptions in Tencent conference API - signature error_ code 200003
Take my brother to do the project. It's cold
Which is the first poem of Tang Dynasty?
1279_ Vsock installation failure resolution when VMware player installs VMware Tools
OpenCV to realize the basic transformation of image
LabVIEW finds prime numbers in an array of n elements
jwt(json web token)
05-ubuntu安装mysql8
普通token
[untitled]
Synthesize video through ffmpeg according to m3u8 file of video on the network
ZUCC_编译语言原理与编译_大作业