How to Train a Model on a Stack Overflow Question Dataset with Python and TensorFlow
Introduction
With the arrival of the big-data era, collecting and applying datasets has become increasingly important. Stack Overflow, the world's largest technical Q&A community, has accumulated a large body of question data that serves both as a resource for researchers and as learning material for programmers. Building on this data, we will use Python and TensorFlow to train a model that predicts the tags of a question.
Obtaining the dataset
First, we need a Stack Overflow question dataset. Here we use the public Stack Overflow question dataset from Kaggle, which contains about 1.7 million questions. We can process the data with Python's pandas package.
>>> import pandas as pd
>>> data_path = 'stackoverflow.csv'
>>> df = pd.read_csv(data_path, usecols=['Title', 'Tags'], nrows=100000)
>>> print(df.head())
                                               Title              Tags
0                    How to install pip on mac os x?        ['python']
1         django login gives Internal Error 500 page        ['django']
2  How do I get ASP.NET Web API and WCF to return...       ['asp.net']
3        Get current time in seconds since the Epoch  ['c++', 'linux']
4                      Configuring Tomcat to Use SSL        ['tomcat']
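Before going further, it can be worth dropping rows where the title or tag string is missing. A minimal sketch on a hypothetical frame mirroring the dataset's Title/Tags columns (the sample rows here are illustrative, not from the real CSV):

```python
import pandas as pd

# Hypothetical rows mirroring the dataset's Title/Tags columns
df = pd.DataFrame({
    'Title': ['How to install pip on mac os x?', None,
              'Configuring Tomcat to Use SSL'],
    'Tags': ["['python']", "['django']", None],
})

# Keep only rows where both columns are present
df = df.dropna(subset=['Title', 'Tags']).reset_index(drop=True)
print(len(df))  # 1
```

`dropna(subset=...)` removes rows with a missing value in any of the listed columns, so only fully populated rows survive.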
Processing the dataset
The code above loads the dataset, keeping only the title and tag columns. The tag information is stored as the string representation of a list, so it needs to be parsed.
>>> from ast import literal_eval
>>> df['Tags'] = df['Tags'].apply(literal_eval)
>>> print(df.head())
                                               Title          Tags
0                    How to install pip on mac os x?      [python]
1         django login gives Internal Error 500 page      [django]
2  How do I get ASP.NET Web API and WCF to return...    [asp.net]
3        Get current time in seconds since the Epoch  [c++, linux]
4                      Configuring Tomcat to Use SSL      [tomcat]
Here we use literal_eval from Python's ast module, which converts the string representation of a list into an actual list.
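As a quick illustration, literal_eval parses a string containing a Python literal without executing arbitrary code (unlike eval, it only accepts literal syntax):

```python
from ast import literal_eval

# Parse the string form of a list into a real Python list
parsed = literal_eval("['c++', 'linux']")
print(type(parsed).__name__, parsed)  # list ['c++', 'linux']
```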
Text preprocessing
The question titles need text preprocessing. Here we use the nltk package for tokenization, stopword removal, and related steps.
>>> from nltk.tokenize import word_tokenize
>>> from nltk.corpus import stopwords
>>> stop_words = set(stopwords.words('english'))
>>> def process_text(text):
... # tokenize text
... tokens = word_tokenize(text)
... # remove stopwords
... filtered_tokens = [word for word in tokens if not word.lower() in stop_words]
... # remove non-alphabetic characters
... alpha_only = [word for word in filtered_tokens if word.isalpha()]
... # return filtered list of tokens
... return alpha_only
>>> titles = df['Title'].apply(process_text)
>>> print(titles.head())
0                        [install, pip, mac, os, x]
1      [django, login, gives, Internal, Error, page]
2       [get, Web, API, WCF, return, JSON, XML, ...]
3        [Get, current, time, seconds, since, Epoch]
4                    [Configuring, Tomcat, Use, SSL]
Name: Title, dtype: object
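Two notes: word_tokenize and the stopword list depend on NLTK data packages ('punkt' and 'stopwords') that need a one-time nltk.download before the code above will run, and the filtering logic itself can be illustrated without NLTK using a naive whitespace tokenizer and a tiny stopword list (both of which are assumptions for illustration only; NLTK's tokenizer and stopword list are richer):

```python
# Tiny illustrative stopword list; NLTK's English list is much longer
STOP_WORDS = {'how', 'to', 'on', 'do', 'i', 'the'}

def process_text(text):
    # Naive whitespace tokenization; NLTK's word_tokenize handles punctuation better
    tokens = text.replace('?', ' ').split()
    # Drop stopwords (case-insensitive)
    filtered = [w for w in tokens if w.lower() not in STOP_WORDS]
    # Keep purely alphabetic tokens only
    return [w for w in filtered if w.isalpha()]

print(process_text('How to install pip on mac os x?'))
# ['install', 'pip', 'mac', 'os', 'x']
```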
Tag processing
The tag information needs to be turned into multi-label binary vectors: one boolean per tag in the tag vocabulary. Here we keep only rows that have at least one tag.
>>> tag_list = sorted({tag for tags in df['Tags'] for tag in tags})  # sorted for a stable column order
>>> tags = df['Tags'].apply(lambda x: [tag in x for tag in tag_list])
>>> print(tags.head())
0 [True, False, False, False, False, False, Fa...
1 [False, False, False, False, False, False, F...
2 [False, False, False, False, False, False, F...
3 [False, False, False, False, False, False, F...
4 [False, False, False, False, False, False, F...
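The same multi-label encoding can also be produced with scikit-learn's MultiLabelBinarizer, which keeps a stable, sorted column order; a small sketch on hypothetical tag lists (assuming scikit-learn is installed):

```python
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical per-question tag lists
tag_lists = [['python'], ['django'], ['c++', 'linux']]

mlb = MultiLabelBinarizer()
y = mlb.fit_transform(tag_lists)  # one 0/1 column per tag, sorted by name

print(list(mlb.classes_))  # ['c++', 'django', 'linux', 'python']
print(y.tolist())          # [[0, 0, 0, 1], [0, 1, 0, 0], [1, 0, 1, 0]]
```

The `classes_` attribute records which tag each column corresponds to, which is needed later to map predictions back to tag names.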
Generating document vectors
To train a model, the text must be converted into vector form. Here we use the Doc2Vec module from the gensim package, which learns a fixed-length dense vector (a document embedding) for each title; note that this is a learned embedding, not a sparse bag-of-words count vector.
>>> from gensim.models.doc2vec import TaggedDocument, Doc2Vec
>>> documents = [TaggedDocument(doc, [i]) for i, doc in enumerate(titles)]
>>> model = Doc2Vec(documents, vector_size=100, window=5, min_count=1, workers=4, epochs=40)
>>> X = [model.dv[i] for i in range(len(documents))]  # model.docvecs in gensim < 4.0
>>> print(X[:5])
[array([ 0.03716 , -0.027939,  0.022339, -0.075994, -0.010817,
         0.004573, -0.002217, -0.055856, -0.006214, -0.048931,
         ...,
        -0.094779, -0.013956, -0.020843, -0.005056], dtype=float32),
 array([...], dtype=float32),
 ...]
(output truncated: five 100-dimensional float32 vectors)
Now that we have a vector representation for each title, we can proceed to the classification task.
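With the document vectors as features and the binarized tags as targets, a multi-label classifier can be trained in TensorFlow/Keras. A minimal sketch on random stand-in data (the shapes, layer sizes, and epoch count are illustrative assumptions, not tuned values; in practice X and y would come from the steps above):

```python
import numpy as np
import tensorflow as tf

# Stand-in shapes: len(X) samples, Doc2Vec vector_size features, len(tag_list) labels
n_samples, vec_dim, n_tags = 200, 100, 20
rng = np.random.default_rng(0)
X_train = rng.normal(size=(n_samples, vec_dim)).astype('float32')
y_train = (rng.random((n_samples, n_tags)) < 0.1).astype('float32')

model = tf.keras.Sequential([
    tf.keras.Input(shape=(vec_dim,)),
    tf.keras.layers.Dense(128, activation='relu'),
    # Sigmoid rather than softmax: each tag is an independent yes/no decision
    tf.keras.layers.Dense(n_tags, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy',
              metrics=['binary_accuracy'])
model.fit(X_train, y_train, epochs=2, batch_size=32, verbose=0)

# Each output is a per-tag probability; threshold (e.g. > 0.5) to pick tags
probs = model.predict(X_train[:1], verbose=0)
print(probs.shape)  # (1, 20)
```

Binary cross-entropy paired with sigmoid outputs treats each tag as its own binary problem, which is the standard setup when a question can carry several tags at once.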