Using MiniLM to compare sentence similarity
2022-08-02 08:28:00 【这个利弗莫尔不太冷】
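The script below is built around the all-MiniLM-L6-v2 model from sentence-transformers: every sentence is encoded into a 384-dimensional vector, and the cosine similarity between two vectors is taken as the similarity between the sentences. A minimal, self-contained sketch of that core step (the two sentences are made up for illustration):

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Encode both sentences into fixed-size vectors.
emb_a = model.encode('How big a cat can this tree hold?', convert_to_tensor=True)
emb_b = model.encode('Will this tower fit an 18 lb cat?', convert_to_tensor=True)

# Cosine similarity; values closer to 1 mean more similar sentences.
print(util.cos_sim(emb_a, emb_b).item())

The full script: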
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
import json

# Local helper module (not shown in the post) that wraps an
# English-to-Chinese translation model; see the sketch below.
from translate_en_to_ch import get_en_to_zh_model

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
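# The module imported above is not included in the post. A plausible minimal
# reconstruction, assuming it wraps a Hugging Face translation pipeline
# (Helsinki-NLP/opus-mt-en-zh is an assumed model choice, not confirmed by
# the author):
#
#     # translate_en_to_ch.py
#     from transformers import pipeline
#
#     def get_en_to_zh_model():
#         # The pipeline is a callable: translator(text, max_length=...) ->
#         # [{'translation_text': ...}], matching the call in get_en_ch below.
#         return pipeline("translation", model="Helsinki-NLP/opus-mt-en-zh")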
def get_sentence_similarity(ori_sentence, contrast_list):
    """Print the similarity between one sentence and each sentence in a list."""
    print('Original sentence:')
    print(ori_sentence)
    # Encode the query once; only the candidates change inside the loop.
    embedding_1 = model.encode(ori_sentence, convert_to_tensor=True)
    for sample in contrast_list:
        embedding_2 = model.encode(sample, convert_to_tensor=True)
        similarity = util.pytorch_cos_sim(embedding_1, embedding_2).item()
        print(sample)
        print(similarity)
def get_list():
    """Load the candidate sentences, one per line."""
    question_txt = '/cloud/cloud_disk/users/huh/nlp/smart_home/script/emdbedding/huhao.txt'
    with open(question_txt, 'r') as fp:
        contents = fp.readlines()
    contrast_list = []
    for sample in contents:
        # Strip the trailing line break (the original slice [:-2] assumed
        # Windows-style '\r\n' endings and would eat a real character otherwise).
        contrast_list.append(sample.rstrip())
    return contrast_list
def get_top10(ori_sentence, contrast_list):
    """Print the ten sentences most similar to ori_sentence."""
    embedding_1 = model.encode(ori_sentence, convert_to_tensor=True)
    scored = []
    for sample in tqdm(contrast_list):
        embedding_2 = model.encode(sample, convert_to_tensor=True)
        similarity = abs(round(util.pytorch_cos_sim(embedding_1, embedding_2).item(), 8))
        # (score, sentence) pairs keep sentences with identical scores,
        # which a score-keyed dict would silently overwrite.
        scored.append((similarity, sample))
    scored.sort(reverse=True)
    for similarity, sample in scored[:10]:
        print('Similarity:')
        print(similarity)
        print('Corresponding sentence:')
        print(sample)
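# A batched variant (a sketch, not from the original post): encoding the whole
# corpus in one call and letting util.semantic_search rank it avoids the
# per-iteration encodes and the manual sorting above.
def get_top10_batched(ori_sentence, contrast_list):
    corpus_embeddings = model.encode(contrast_list, convert_to_tensor=True)
    query_embedding = model.encode(ori_sentence, convert_to_tensor=True)
    # Returns one result list per query: dicts with 'corpus_id' and 'score'.
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=10)[0]
    for hit in hits:
        print(contrast_list[hit['corpus_id']], hit['score'])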
def is_contain_chinese(check_str):
    """Return True if the string contains any Chinese character, else False."""
    for ch in check_str:
        if u'\u4e00' <= ch <= u'\u9fff':
            return True
    return False
def check_wrong_format():
    """Repair CSV rows whose sentence field itself contains commas."""
    with open('/cloud/cloud_disk/users/huh/nlp/smart_home/similitary.csv', 'r') as fp:
        contents = fp.readlines()
    all_samples_list = []
    for sample in contents:
        sample_tmp = sample.split(',')
        sentence = ''
        if len(all_samples_list) % 23 == 0:
            # Every 23rd collected row is treated as a header-style row:
            # merge everything after the first field back into one column.
            for index in range(1, len(sample_tmp)):
                sentence += sample_tmp[index]
            all_samples_list.append("{},{}".format(sample_tmp[0], sentence))
            continue
        try:
            # Chinese rows carry only a sentence; English rows also carry
            # a score and an id in their last two columns.
            if is_contain_chinese(sample_tmp[1]):
                if len(sample_tmp) > 3:
                    for index in range(1, len(sample_tmp) - 1):
                        sentence += sample_tmp[index]
                    all_samples_list.append("{},{},{}".format(sample_tmp[0], sentence, sample_tmp[-1]))
                else:
                    for index in range(1, len(sample_tmp)):
                        sentence += sample_tmp[index]
                    all_samples_list.append("{},{}".format(sample_tmp[0], sentence))
            else:
                if len(sample_tmp) > 4:
                    for index in range(1, len(sample_tmp) - 2):
                        sentence += sample_tmp[index]
                    all_samples_list.append("{},{},{},{}".format(sample_tmp[0], sentence, sample_tmp[-2], sample_tmp[-1]))
                else:
                    all_samples_list.append(sample)
        except IndexError:
            all_samples_list.append(sample)
    # BOM-marked UTF-8 so Excel renders the Chinese text correctly.
    with open('s2.csv', 'w', encoding="utf_8_sig") as fp:
        fp.writelines(all_samples_list)
def solve_luanma():
    """Re-encode the CSV as BOM-marked UTF-8 to fix garbled characters in Excel."""
    import pandas as pd
    file_name = '/cloud/cloud_disk/users/huh/nlp/smart_home/similitary.csv'
    df = pd.read_csv(file_name, encoding='utf-8')
    file_name3 = '/cloud/cloud_disk/users/huh/nlp/smart_home/similitary2.csv'
    df.to_csv(file_name3, encoding="utf_8_sig", sep=',')
def get_en_ch(contrast_list):
    """Build a dict mapping each English sentence to its Chinese translation."""
    translation = get_en_to_zh_model()
    en_to_ch_dict = {}
    for sample in tqdm(contrast_list):
        ch = translation(sample, max_length=1024)[0]['translation_text']
        # Remove commas from the translation so it stays in one CSV column.
        en_to_ch_dict[sample] = ch.replace(',', '')
    return en_to_ch_dict
def get_all_samples(contrast_list, en_to_ch, sample_to_id):
    """Write a CSV listing the top-10 most similar sentences for every sentence."""
    # Encode every sentence once up front instead of once per pair.
    sample_to_embedding = {}
    for sample in tqdm(contrast_list):
        sample_to_embedding[sample] = model.encode(sample, convert_to_tensor=True)
    xls_list = []
    for sample in tqdm(contrast_list):
        scored = []
        for tmp_sample in contrast_list:
            similarity = abs(round(util.pytorch_cos_sim(
                sample_to_embedding[sample], sample_to_embedding[tmp_sample]).item(), 8))
            scored.append((similarity, tmp_sample))
        scored.sort(reverse=True)
        xls_list.append('Original:,{},{},{}\n'.format(sample, ' ', sample_to_id[sample]))
        xls_list.append('Original:,{}\n'.format(en_to_ch[sample]))
        # scored[0] is the sentence compared with itself (similarity 1.0), so skip it.
        for similarity, tmp_sample in scored[1:11]:
            xls_list.append(" ,{},{},{}\n".format(tmp_sample, similarity, sample_to_id[tmp_sample]))
            xls_list.append(" ,{}\n".format(en_to_ch[tmp_sample]))
        xls_list.append("\n")
    with open('similitary.csv', 'w', encoding="utf_8_sig") as fp:
        fp.writelines(xls_list)
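# A vectorized variant (a sketch, not from the original post): util.cos_sim
# accepts whole matrices, so the Python-level n^2 loop above can become one
# matrix product plus torch.topk (assumes len(contrast_list) >= 11).
def top10_matrix(contrast_list):
    import torch
    embeddings = model.encode(contrast_list, convert_to_tensor=True)  # (n, 384)
    sim = util.cos_sim(embeddings, embeddings)                        # (n, n)
    # k=11 because the best match of every sentence is the sentence itself.
    scores, indices = torch.topk(sim, k=11, dim=1)
    for row, sample in enumerate(contrast_list):
        print(sample)
        for s, i in zip(scores[row][1:], indices[row][1:]):
            print('  {}  {:.4f}'.format(contrast_list[i.item()], s.item()))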
def get_sample_to_id():
    """Map each (comma-stripped) question to its product ASIN."""
    sample_to_id = {}
    json_path = '/cloud/cloud_disk/users/huh/nlp/smart_home/script/emdbedding/cattree.json'
    with open(json_path, 'r') as f:
        json_data = json.load(f)
    for sample in json_data['samples']:
        asin = sample['asin']
        for tmp in sample['QAS']:
            # The leading space and comma removal must match remove_douhao,
            # since these strings are used as dictionary keys.
            sentence = ' ' + tmp['question'].replace(',', '')
            sample_to_id[sentence] = asin
    return sample_to_id
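# get_sample_to_id expects cattree.json to look roughly like the following.
# This is a hypothetical minimal example inferred from the keys the code
# reads ('samples', 'asin', 'QAS', 'question'), not the author's actual data:
#
#     {"samples": [{"asin": "B000EXAMPLE",
#                   "QAS": [{"question": "How big of a cat will this hold?"}]}]}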
def remove_douhao(contrast_list):
    """Strip commas ("douhao") so each sentence stays in one CSV column."""
    end_list = []
    for sample in contrast_list:
        # Keep the leading space: get_sample_to_id builds its keys the same way.
        end_list.append(' ' + sample.replace(',', ''))
    return end_list
if __name__ == '__main__':
    ori_sentence = 'Has anyone found a good way to remove stains on this cat tree? '
    contrast_list = ['Will the cat tower basket fit an adult cat of 18 lbs? ',
                     'How big of a cat will this structure hold? One of my cats weight 21 pounds. ',
                     'Can these shelves be installed onto concrete walls?',
                     'Does the chestnut/natural set in nclude everything in the picture including the planters??']
    # get_sentence_similarity(ori_sentence, contrast_list)
    contrast_list = get_list()
    # Deduplicate, then strip commas so each sentence fits one CSV column.
    contrast_list = list(set(contrast_list))
    contrast_list = remove_douhao(contrast_list)
    en_to_ch = get_en_ch(contrast_list)
    sample_to_id = get_sample_to_id()
    get_all_samples(contrast_list, en_to_ch, sample_to_id)