Wide & Deep: understanding the model and its optimizers, with code practice
2022-06-24 04:05:00 [Goose]
1. Background
Wide & Deep is a model for classification and regression problems published by Google in 2016. It was applied to app recommendation in Google Play, where it effectively increased app installs. The model has been open-sourced and is available as a high-level API in TensorFlow.
The goal of Wide & Deep is to train a single model that has both memorization and generalization ability:
- Memorization: the model learns frequently co-occurring feature combinations from historical data and exploits the correlations between features. Crossing features gives the model an explicit "memory" of feature interactions, which is effective and interpretable; but generalizing beyond them requires more feature engineering.
- Generalization: the model exploits the transitivity of correlations to explore feature combinations that never appeared in the historical data. Through embeddings, low-dimensional dense feature inputs allow the model to generalize better to feature crosses that never appeared in the training samples.
The main contributions of the paper are:
- the Wide & Deep framework, which jointly trains a feed-forward neural network with embeddings and a linear model with feature transformations, for recommender systems with sparse inputs;
- an online evaluation of Wide & Deep on the Google Play app store;
- an open-source implementation contributed to the TensorFlow API, making the model easy to use.
The features used include:
- user features (city, age, demographic features, etc.)
- contextual features (device, time of the request, day of the week)
- app features (how long the app has been online, historical statistics of the app)
2. Model structure
2.1 Wide part
The wide part corresponds to the left side of the architecture figure in the paper. It is usually a generalized linear model (LR): y = w^T x + b, where:
- y: the prediction
- x: the feature vector
- w: the model weights
- b: the bias
The feature set contains the raw inputs and their transformations. The most important transformation is the cross-product transformation (each feature must be one-hot encoded before crossing):

\phi_k(x) = \prod_{i=1}^{d} x_i^{c_{ki}}, \quad c_{ki} \in \{0, 1\}

where c_{ki} is 1 if the i-th feature is part of the k-th transformation \phi_k and 0 otherwise. For one-hot encoded features this captures an AND: for example, AND(gender=female, language=en) is 1 only when gender=female and language=en are both 1, and 0 otherwise.
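A minimal illustrative sketch of this transformation over binary one-hot features (the helper name and toy vectors below are ours, not from the paper's code):

import numpy as np

def cross_product_transform(x, c):
    """Cross-product transformation: phi_k(x) = prod_i x_i^{c_ki}.

    x: binary one-hot feature vector, shape (d,)
    c: binary exponent matrix, shape (k, d); c[k, i] = 1 if feature i
       participates in the k-th cross, 0 otherwise.
    Returns k cross features, each 1 only when every participating
    feature is 1.
    """
    # x**0 == 1, so features with c_ki == 0 do not affect the product
    return np.prod(np.power(x, c), axis=1)

# x = [gender=female, language=en]; one cross: AND(gender=female, language=en)
x = np.array([1, 1])
c = np.array([[1, 1]])
print(cross_product_transform(x, c))  # [1]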
2.2 Deep part
The deep part corresponds to the right side of the architecture figure and is a feed-forward neural network. For categorical features, the raw input is a string (e.g. language=en). Each of these sparse, high-dimensional categorical features is first converted into a low-dimensional dense vector, usually of 10 to 100 dimensions. These vectors are typically randomly initialized and then optimized during training by minimizing the loss function. The low-dimensional vectors are fed into the hidden layers of the neural network, each of which computes:

a^{(l+1)} = f(W^{(l)} a^{(l)} + b^{(l)})

where:
- l: the layer index
- f: the activation function (usually ReLU)
- a^{(l)}: the output of layer l
- b^{(l)}: the bias of layer l
- W^{(l)}: the weights of layer l
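A minimal numpy sketch of one such hidden-layer step (the layer sizes below are made up for illustration):

import numpy as np

def relu(z):
    return np.maximum(0.0, z)

def hidden_layer(a_l, W_l, b_l, f=relu):
    # One feed-forward step: a^{(l+1)} = f(W^{(l)} a^{(l)} + b^{(l)})
    return f(W_l @ a_l + b_l)

rng = np.random.default_rng(0)
a0 = rng.normal(size=32)         # e.g. concatenated embedding vectors
W0 = rng.normal(size=(16, 32))   # hypothetical 32 -> 16 layer
b0 = np.zeros(16)
a1 = hidden_layer(a0, W0, b0)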
2.3 Wide & Deep Joint training
The wide part and the deep part are combined using a weighted sum of their output log odds as the prediction, which is fed into one common logistic loss function for joint training. The paper emphasizes the difference between joint training and ensembling:
- Ensemble: the two models are trained independently and only combined at prediction time; each individual model then needs to be larger (e.g. use more feature transformations) for the combination to reach a given accuracy.
- Joint training: during training, the wide part, the deep part, and the weights of their combination are optimized simultaneously; the wide part only needs a small number of cross features to make up for the weaknesses of the deep part.
Joint training of Wide & Deep is optimized with mini-batch stochastic optimization. In the experiments, the wide part uses Follow-the-regularized-leader (FTRL) with L1 regularization, and the deep part uses AdaGrad.
FTRL with L1 regularization drives most of the weights of the wide part to exactly 0, so the features whose weights are 0 do not need to be kept when preparing the features. This greatly compresses the model weights and also reduces the dimensionality of the feature vector.
The inputs to the deep part are either numeric features such as Age and #App Installs, or embedding vectors that are already low-dimensional and dense; engineers do not (and should not) feed extremely sparse feature vectors directly into the deep network. The deep part therefore has no serious feature-sparsity problem and can use AdaGrad, which gives better accuracy here and is well suited to training deep networks.
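As a hedged sketch of how these two optimizers could be instantiated with the Keras optimizer API (the learning rates and L1 strength below are illustrative, not values reported in the paper):

import tensorflow as tf

# FTRL with L1 regularization for the wide (linear) part; the L1 term
# produces sparse weights, as discussed above.
wide_optimizer = tf.keras.optimizers.Ftrl(
    learning_rate=0.001,
    l1_regularization_strength=0.5,
)

# AdaGrad for the deep part.
deep_optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.01)

Optimizer objects like these can be passed as linear_optimizer and dnn_optimizer to the DNNLinearCombinedClassifier used in the code example below.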
For LR, the prediction of the model is:

P(Y = 1 | x) = \sigma(w_{wide}^T [x, \phi(x)] + w_{deep}^T a^{(l_f)} + b)

where:
- Y: the binary label
- \sigma(): the sigmoid function
- \phi(x): the cross-product transformations of the raw features x
- b: the bias
- w_{wide}: the weights of the wide model, applied to [x, \phi(x)]
- w_{deep}: the weights applied to the final deep activations a^{(l_f)}
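A minimal numpy sketch of this combined prediction (all names and shapes below are illustrative):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def wide_deep_predict(x, phi_x, a_lf, w_wide, w_deep, b):
    # P(Y=1|x) = sigmoid(w_wide^T [x, phi(x)] + w_deep^T a^(l_f) + b)
    wide_input = np.concatenate([x, phi_x])  # raw features + cross features
    logit = w_wide @ wide_input + w_deep @ a_lf + b
    return sigmoid(logit)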
3. Practice
The architecture diagram in the original paper does not have to be followed strictly in practice. For example, in the diagram the wide part uses only cross features; in practice we can also add raw discrete features or discretized continuous features.
3.1 Feature engineering and processing
- User features: basic features such as account age and time since the last visit; behavioral features such as activity / browse / follow / IM counts over the last 3/7/15/30/90 days; plus profile-preference features and conversion-rate features;
- Item features: basic item features, plus continuous features such as popularity and click-through rate;
- Cross features: crosses between profile preferences and item features.
- Missing values and outliers: standard practice; different features use different imputation strategies; outliers are handled with quartiles;
- Equal-frequency bucketing: standard practice; a feature like price has a long-tailed distribution, so most samples' values are concentrated in a small range, which reduces the discriminative power of the feature (see the sketch after this list);
- Normalization: standard practice; it gave a clear lift;
- Low-frequency filtering: standard practice; for discrete features, very low-frequency values are merged into a single category;
- Embedding.
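A minimal pandas sketch of the equal-frequency bucketing and low-frequency filtering steps (the column names, toy data, and thresholds are made up):

import pandas as pd

df = pd.DataFrame({
    "price": [1.0, 2.0, 2.5, 3.0, 100.0, 5000.0],  # long-tailed toy data
    "city": ["a", "a", "b", "c", "d", "a"],
})

# Equal-frequency bucketing: qcut puts (roughly) the same number of
# samples in each bucket, so the long tail does not squash most
# values into a single bin.
df["price_bucket"] = pd.qcut(df["price"], q=3, labels=False, duplicates="drop")

# Low-frequency filtering: merge rare categories into one "other" class.
counts = df["city"].value_counts()
rare = counts[counts < 2].index
df["city_filtered"] = df["city"].where(~df["city"].isin(rare), "other")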
3.2 Offline training
- Data split: 7 days of data as the training set, 1 day as the test set
- embedding:
- Model tuning:
  - Preventing overfitting: add dropout and L2 regularization
  - Speeding up convergence: introduce Batch Normalization
  - Training stability and convergence: try different learning rates (0.001 on the wide side and 0.01 on the deep side worked best) and batch sizes (currently set to 2048)
  - Optimizer: compare SGD, Adam, AdaGrad and other optimizers (a configuration sketch follows this list)
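A hedged sketch of how these tuning knobs map onto the estimator used in the code example below (dnn_dropout and batch_norm are real DNNLinearCombinedClassifier arguments; the values here are illustrative, and base_columns / crossed_columns / deep_columns are defined later in this article):

import tensorflow as tf

model = tf.estimator.DNNLinearCombinedClassifier(
    model_dir="./model/wd/",
    linear_feature_columns=base_columns + crossed_columns,
    linear_optimizer=tf.keras.optimizers.Ftrl(learning_rate=0.001),
    dnn_feature_columns=deep_columns,
    dnn_optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.01),
    dnn_hidden_units=[100, 50],
    dnn_dropout=0.5,   # dropout against overfitting
    batch_norm=True,   # Batch Normalization to speed up convergence
    n_classes=2,
)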
The paper mentions one practical point: retraining from scratch every time takes a lot of time and compute. The solution adopted is warm starting: whenever new training data is generated, the new model is initialized with the embeddings and the linear-model weights read from the previous model, and the new model is verified against the previous one before it is connected to the real-time stream, to make sure nothing goes wrong.
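TensorFlow's canned estimators support this pattern through the warm_start_from argument; a hedged sketch (the checkpoint paths are illustrative):

import tensorflow as tf

# Initialize the new model from the previous run's checkpoint directory;
# variables with matching names (embeddings, linear weights, ...) are
# copied over instead of being randomly initialized.
model = tf.estimator.DNNLinearCombinedClassifier(
    model_dir="./model/wd_new/",
    linear_feature_columns=base_columns + crossed_columns,
    dnn_feature_columns=deep_columns,
    dnn_hidden_units=[100, 50],
    warm_start_from="./model/wd/",  # previous model's checkpoint dir
)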
4. Extensions
Users or items to be recommended sometimes come with text and images; to improve the model, multimodal features can be used. A few simple ideas (a sketch of the third one follows this list):
- Take embedding vectors of the text and the image and add them to the overall model the same way the wide part is added. How the two embedding vectors are obtained is up to you.
- Apply attention between the text and image embeddings, then combine them.
- Concatenate the text and image embeddings with the output of the deep tower, then apply further processing.
- Paper keyword: Multimodal Fusion
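A minimal Keras sketch of the third idea, concatenating precomputed text/image embeddings with the output of the deep tower (all dimensions and input names are hypothetical):

import tensorflow as tf

# Hypothetical inputs: deep-tower features plus precomputed embeddings.
deep_in = tf.keras.Input(shape=(64,), name="deep_features")
text_emb = tf.keras.Input(shape=(32,), name="text_embedding")
img_emb = tf.keras.Input(shape=(32,), name="image_embedding")

# Deep tower.
h = tf.keras.layers.Dense(100, activation="relu")(deep_in)
h = tf.keras.layers.Dense(50, activation="relu")(h)

# Concatenate the deep output with the multimodal embeddings,
# then process the fused vector again.
fused = tf.keras.layers.Concatenate()([h, text_emb, img_emb])
fused = tf.keras.layers.Dense(32, activation="relu")(fused)
out = tf.keras.layers.Dense(1, activation="sigmoid")(fused)

model = tf.keras.Model([deep_in, text_emb, img_emb], out)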
5. Code example
First, read in the Adult (census income) data:

import pandas as pd
import tensorflow as tf

train_data = "./../data/adult/adult.train"
test_data = "./../data/adult/adult.test"
train = pd.read_csv(train_data, sep=",",
                    names=["age", "workclass", "fnlwgt", "education", "education_num",
                           "marital_status", "occupation", "relationship", "race", "sex",
                           "capital_gain", "capital_loss", "hours_per_week",
                           "native_country", "label"])
print(train.head(5))
   age         workclass  fnlwgt  ...  hours_per_week  native_country  label
0   39         State-gov   77516  ...              40   United-States  <=50K
1   50  Self-emp-not-inc   83311  ...              13   United-States  <=50K
2   38           Private  215646  ...              40   United-States  <=50K
3   53           Private  234721  ...              40   United-States  <=50K
4   28           Private  338409  ...              40            Cuba  <=50K
Define the basic features, the continuous features, and the features used by the dnn:
# Define basic continuous features; used by both the linear and dnn parts
age = tf.feature_column.numeric_column("age")
education_num = tf.feature_column.numeric_column("education_num")
capital_gain = tf.feature_column.numeric_column("capital_gain")
capital_loss = tf.feature_column.numeric_column("capital_loss")
hours_per_week = tf.feature_column.numeric_column("hours_per_week")
# Define discrete (categorical) features
workclass = tf.feature_column.categorical_column_with_vocabulary_list(
key="workclass",
vocabulary_list=["Private", "Self-emp-not-inc", "Self-emp-inc", "Federal-gov", "Local-gov", "State-gov",
"Without-pay", "Never-worked", "?"]
)
education = tf.feature_column.categorical_column_with_vocabulary_list(
key="education",
vocabulary_list=["Bachelors", "Some-college", "11th", "HS-grad", "Prof-school", "Assoc-acdm", "Assoc-voc", "9th",
"7th-8th", "12th", "Masters", "1st-4th", "10th", "Doctorate", "5th-6th", "Preschool"]
)
marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
key="marital_status",
vocabulary_list=["Married-civ-spouse", "Divorced", "Never-married", "Separated", "Widowed", "Married-spouse-absent",
"Married-AF-spouse"]
)
relationship = tf.feature_column.categorical_column_with_vocabulary_list(
key="relationship",
vocabulary_list=["Wife", "Own-child", "Husband", "Not-in-family", "Other-relative", "Unmarried"]
)
# Define a hash feature, to demonstrate embedding
occupation = tf.feature_column.categorical_column_with_hash_bucket(
key="occupation",
hash_bucket_size=1000
)
# Bucketize the age feature
age_bucket = tf.feature_column.bucketized_column(
source_column=age,
boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65]
)
base_columns = [workclass, education, marital_status, relationship, occupation, age_bucket]
crossed_columns = [
tf.feature_column.crossed_column(
keys=["education", "occupation"], hash_bucket_size=1000
),
tf.feature_column.crossed_column(
keys=[age_bucket, "education", "occupation"], hash_bucket_size=1000
)
]
deep_columns = [
age,
education_num,
capital_gain,
capital_loss,
hours_per_week,
tf.feature_column.indicator_column(workclass),  # one-hot encode, then feed into the dnn layers
tf.feature_column.indicator_column(education),
tf.feature_column.indicator_column(marital_status),
tf.feature_column.indicator_column(relationship),
# to demonstrate embedding
tf.feature_column.embedding_column(occupation, dimension=8)
]

Define the data:
# Defining data
_CSV_COLUMNS = [
"age", "workclass", "fnlwgt", "education", "education_num",
"marital_status", "occupation", "relationship", "race", "sex",
"capital_gain", "capital_loss", "hours_per_week", "native_country", "label"
]
_CSV_COLUMN_DEFAULTS = [
[0], [''], [0], [''], [0],
[''], [''], [''], [''], [''],
[0], [0], [0], [''], ['']
]
_NUM_EXAMPLES = {
"train": 32561,
"validation": 16281
}

Define the model:
def create_model():
    model = tf.estimator.DNNLinearCombinedClassifier(
        model_dir="./model/wd/",
        linear_feature_columns=base_columns + crossed_columns,
        dnn_feature_columns=deep_columns,
        dnn_hidden_units=[100, 50],
        linear_optimizer="Ftrl",
        dnn_optimizer="Adagrad",
        n_classes=2,
        batch_norm=False
    )
    return model
Define the input_fn function:
def input_fn(data_file, num_epochs, shuffle, batch_size):
    """Create an input function for the Estimator."""
    assert tf.io.gfile.exists(data_file), "{0} not found.".format(data_file)

    def parse_csv(line):
        # tf.io.decode_csv converts each CSV line into a list of Tensors, one per column
        # record_defaults specifies the value used to fill missing values in each column
        columns = tf.io.decode_csv(line, record_defaults=_CSV_COLUMN_DEFAULTS)
        features = dict(zip(_CSV_COLUMNS, columns))
        labels = features.pop('label')
        # tf.equal(x, y) returns a bool Tensor representing x == y, element-wise
        # note the leading space in the raw data
        return features, tf.equal(labels, ' >50K')

    dataset = tf.data.TextLineDataset(data_file).map(parse_csv, num_parallel_calls=5)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'] + _NUM_EXAMPLES['validation'])
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    return dataset

The main function:
if __name__ == "__main__":
    train_epochs = 20
    batch_size = 256
    model = create_model()
    for n in range(train_epochs):
        print("train model start ...")
        model.train(input_fn=lambda: input_fn(train_data, train_epochs, True, batch_size))
        predict_results = model.predict(input_fn=lambda: input_fn(test_data, train_epochs, False, batch_size))
        print("test model start ...")
        results = model.evaluate(input_fn=lambda: input_fn(test_data, train_epochs, False, batch_size))
        # print(results)
        print('{0:-^30}'.format('evaluate at epoch %d' % (n + 1)))
        # results is a dictionary
        print(pd.Series(results).to_frame('values'))

After running 20 epochs, the output is:
-----evaluate at epoch 20-----
                          values
accuracy                0.826301
accuracy_baseline       0.763774
auc                     0.852878
auc_precision_recall    0.686778
average_loss            0.381445
label/mean              0.236226
loss                    0.381446
precision               0.727232
prediction/mean         0.249888
recall                  0.423557
global_step         51008.000000

Process finished with exit code 0
Ref
- https://mp.weixin.qq.com/s?__biz=MzIyNTY1MDUwNQ==&mid=2247484238&idx=1&sn=c9700da77cad73f91420fe4309ff0100&chksm=e87d3168df0ab87eb8721dc7220877fb43ae5e66061c7b286fb570720dcc349d3d3d4346eeae&scene=21#wechat_redirect
- https://mp.weixin.qq.com/s/UOuT-E8g22iRRoKC9r23YQ
- https://mp.weixin.qq.com/s/Vur4kvsiXRbfOYyso81WVg
- https://mp.weixin.qq.com/s?__biz=MzI2ODA3NjcwMw==&mid=2247483659&idx=1&sn=deb9c5e22eabd3c52d2418150a40c68a&chksm=eaf452fbdd83dbed0d6de5e847e8569bdc0a75ef6aa23fcaa9c5586a2572cd0e216f499a529b&scene=21#wechat_redirect
- https://liam.page/2019/08/31/a-not-so-simple-introduction-to-FTRL/ (an introduction to FTRL)
- https://zhuanlan.zhihu.com/p/142958834