当前位置:网站首页>利用MiniLM比较句子之间的相似度
利用MiniLM比较句子之间的相似度
2022-08-02 08:28:00 【这个利弗莫尔不太冷】
from turtle import end_fill
from sentence_transformers import SentenceTransformer, util
from textblob import Sentence
from tqdm import tqdm
# Shared sentence-embedding model used by the similarity functions in this script
# (get_sentence_similitay, get_top10, get_all_samples). It is constructed once, at
# import time, so importing this module already pays the model-loading cost.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
from translate_en_to_ch import get_en_to_zh_model,en_to_ch
import json
def get_sentence_similitay(ori_sentence, contrast_list):
    """Print the cosine similarity between ``ori_sentence`` and each entry of ``contrast_list``.

    Embeddings come from the module-level ``model``. Results are printed (sentence,
    then score); nothing is returned.

    :param ori_sentence: reference sentence.
    :param contrast_list: iterable of sentences to compare against the reference.
    """
    print('原始的句子')
    print(ori_sentence)
    # The reference sentence never changes inside the loop, so encode it only once
    # instead of once per compared sentence.
    embedding_1 = model.encode(ori_sentence, convert_to_tensor=True)
    for sample in contrast_list:
        embedding_2 = model.encode(sample, convert_to_tensor=True)
        sentence_similitay = util.pytorch_cos_sim(embedding_1, embedding_2).item()
        print(sample)
        print(sentence_similitay)
def get_list(question_txt='/cloud/cloud_disk/users/huh/nlp/smart_home/script/emdbedding/huhao.txt'):
    """Read one question per line from ``question_txt`` and return them in file order.

    Only the line terminator is removed from each line; any other whitespace is kept.

    :param question_txt: path of the UTF-8 text file to read; defaults to the
        original hard-coded location.
    :return: list of lines without their trailing newline.
    """
    # Text mode (newline=None) already converts '\r\n' to '\n', so the previous
    # ``sample[:-2]`` cut off the last real character of every line (and two
    # characters of an unterminated final line). Strip just the terminator instead.
    with open(question_txt, 'r', encoding='utf-8') as fp:
        return [line.rstrip('\r\n') for line in fp]
def get_top10(ori_sentence, contrast_list, top_k=100):
    """Print the sentences of ``contrast_list`` most similar to ``ori_sentence``.

    The score is the absolute cosine similarity of the module-level ``model``
    embeddings, rounded to 8 decimal places. Results are printed from the highest
    score down; sentences with equal scores keep their ``contrast_list`` order.

    :param ori_sentence: reference sentence.
    :param contrast_list: sentences to rank.
    :param top_k: number of results to print. Defaults to 100, which is what this
        function has always printed despite its name.
    """
    # The reference embedding is loop-invariant; compute it once.
    embedding_1 = model.encode(ori_sentence, convert_to_tensor=True)
    scored = []
    for sample in tqdm(contrast_list):
        embedding_2 = model.encode(sample, convert_to_tensor=True)
        similarity = abs(float(round(util.pytorch_cos_sim(embedding_1, embedding_2).item(), 8)))
        scored.append((similarity, sample))
    # Keep each score paired with its sentence. The old dict keyed by score let
    # sentences with identical scores overwrite each other, so one sentence was
    # printed several times and the others were lost.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    for similarity, sample in scored[:top_k]:
        print('相似度为')
        print(similarity)
        print('对应的句子为')
        print(sample)
def is_contain_chinese(check_str):
    """Tell whether ``check_str`` contains at least one CJK Unified Ideograph.

    Only the basic block U+4E00..U+9FFF is checked; CJK punctuation and the
    extension blocks are not counted.

    :param check_str: {str} string to inspect
    :return: {bool} True if any character lies in U+4E00..U+9FFF, otherwise False
    """
    return any('\u4e00' <= character <= '\u9fff' for character in check_str)
def check_wrong_format(src_path='/cloud/cloud_disk/users/huh/nlp/smart_home/similitary.csv',
                       dst_path='s2.csv'):
    """Re-join CSV rows whose sentence field was split apart by stray ASCII commas.

    Each input row is split on ',' and the fields that belong to the sentence are
    merged back into one field. The result is written to ``dst_path`` with a UTF-8 BOM.

    * Row whose second field contains Chinese: with more than 3 fields, fields
      ``[1:-1]`` are merged and the last field is kept separate; otherwise every
      field after the first is merged.
    * Any other row: with more than 4 fields, fields ``[1:-2]`` are merged and the
      last two fields are kept; otherwise the row is copied unchanged.
    * Row with no comma at all: copied unchanged.

    :param src_path: CSV to repair (defaults to the original hard-coded path;
        presumably the report produced by get_all_samples).
    :param dst_path: output file (defaults to ``s2.csv`` in the working directory).
    """
    # 'utf_8_sig' drops a leading BOM if present, so the first row does not start
    # with '\ufeff' and the output (written with a BOM) does not end up with two.
    with open(src_path, 'r', encoding='utf_8_sig') as fp:
        contents = fp.readlines()

    all_samples_list = []
    for sample in contents:
        fields = sample.split(',')
        sentence = ''
        if len(all_samples_list) % 23 == 0:
            # NOTE(review): 23 matches the rows per group written by get_all_samples
            # (2 header rows + 10 x 2 neighbour rows + 1 blank), but this counts
            # *output* rows and does not skip the checks below, so such a row is
            # appended a second time. Behaviour kept as-is; confirm intent.
            sentence = ''.join(fields[1:])
            all_samples_list.append("{},{}".format(fields[0], sentence))

        if len(fields) < 2:
            # Nothing to re-join (this is the case the old bare ``except`` handled).
            all_samples_list.append(sample)
            continue

        if is_contain_chinese(fields[1]):
            # Chinese sentence row
            if len(fields) > 3:
                sentence += ''.join(fields[1:-1])
                all_samples_list.append("{},{},{}".format(fields[0], sentence, fields[-1]))
            else:
                sentence += ''.join(fields[1:])
                # Previously ",{}".format() raised IndexError here and the raw row
                # was written instead of the re-joined one.
                all_samples_list.append("{},{}".format(fields[0], sentence))
        elif len(fields) > 4:
            # English sentence row
            sentence += ''.join(fields[1:-2])
            # The old template had only three placeholders, so the last field was
            # silently dropped. That field carries the trailing newline, and in rows
            # from get_all_samples also the id.
            all_samples_list.append("{},{},{},{}".format(fields[0], sentence, fields[-2], fields[-1]))
        else:
            all_samples_list.append(sample)

    with open(dst_path, 'w', encoding="utf_8_sig") as fp:
        fp.writelines(all_samples_list)
def solve_luanma():
    """Re-save the similarity CSV with a UTF-8 BOM ("luanma" = 乱码, garbled text).

    Reads ``similitary.csv`` as UTF-8 with pandas and writes it to
    ``similitary2.csv`` using the ``utf_8_sig`` codec. Presumably the BOM is meant
    to stop spreadsheet tools from garbling the Chinese text. Note that pandas also
    writes its row index as an extra first column.
    """
    import pandas as pd

    source_csv = '/cloud/cloud_disk/users/huh/nlp/smart_home/similitary.csv'
    frame = pd.read_csv(source_csv, encoding='utf-8')
    target_csv = '/cloud/cloud_disk/users/huh/nlp/smart_home/similitary2.csv'
    frame.to_csv(target_csv, encoding="utf_8_sig", sep=',')
def get_en_ch(contrast_list):
    """Translate each English sentence to Chinese and return an ``{english: chinese}`` dict.

    The translator comes from get_en_to_zh_model(). It is called as
    ``translator(text, max_length=1024)`` and its first result's
    ``'translation_text'`` is used. ASCII commas are removed from every translation.
    The raw translation, source sentence, comma-split pieces and joined result are
    printed for debugging.
    """
    translator = get_en_to_zh_model()
    mapping = {}
    for english in tqdm(contrast_list):
        chinese = translator(english, max_length=1024)[0]['translation_text']
        print(chinese)
        pieces = chinese.split(',')
        joined = ''.join(pieces)
        print(english)
        print(pieces)
        print(joined)
        mapping[english] = joined
    return mapping
def get_all_samples(contrast_list, en_to_ch, sample_to_id, top_k=10,
                    output_path='similitary.csv'):
    """Write, for every sentence, its ``top_k`` most similar *other* sentences to a CSV.

    Each group written to the file consists of:

    * ``原句子:,<sentence>, ,<id>`` and ``原句子:,<Chinese translation>``
    * per neighbour: `` ,<neighbour>,<similarity>,<id>`` then `` ,<Chinese translation>``
    * one blank line

    Similarity is the absolute cosine similarity of the module-level ``model``
    embeddings, rounded to 8 decimal places. Neighbours are listed from the highest
    score down; equal scores keep ``contrast_list`` order.

    :param contrast_list: sentences to compare with each other.
    :param en_to_ch: mapping sentence -> Chinese translation (as built by get_en_ch).
    :param sample_to_id: mapping sentence -> id (as built by get_sample_to_id).
    :param top_k: neighbours written per sentence (default 10, as before).
    :param output_path: CSV written with a UTF-8 BOM (default ``similitary.csv``).
    :raises KeyError: if a sentence is missing from ``en_to_ch`` or ``sample_to_id``.
    """
    # Encode every sentence once; the pairwise loop below only does lookups.
    sample_to_embedding = {}
    for sample in tqdm(contrast_list):
        sample_to_embedding[sample] = model.encode(sample, convert_to_tensor=True)

    xls_list = []
    for sample in tqdm(contrast_list):
        embedding_1 = sample_to_embedding[sample]
        scored = []
        for tmp_sample in contrast_list:
            # Skip the sentence itself explicitly instead of assuming it is ranked
            # first and slicing it off with [1:].
            if tmp_sample == sample:
                continue
            embedding_2 = sample_to_embedding[tmp_sample]
            similarity = abs(float(round(util.pytorch_cos_sim(embedding_1, embedding_2).item(), 8)))
            scored.append((similarity, tmp_sample))
        # Keep (score, sentence) pairs together. The old dict keyed by score plus
        # set() dedup kept only one sentence per distinct score, dropping tied
        # neighbours. sort() is stable even with reverse=True.
        scored.sort(key=lambda pair: pair[0], reverse=True)

        xls_list.append('原句子:,{},{},{}\n'.format(sample, ' ', sample_to_id[sample]))
        xls_list.append('原句子:,{}\n'.format(en_to_ch[sample]))
        for similarity, neighbour in scored[:top_k]:
            xls_list.append(" ,{},{},{}\n".format(neighbour, similarity, sample_to_id[neighbour]))
            xls_list.append(" ,{}\n".format(en_to_ch[neighbour]))
        xls_list.append("\n")

    with open(output_path, 'w', encoding="utf_8_sig") as fp:
        fp.writelines(xls_list)
def get_sample_to_id(json_path='/cloud/cloud_disk/users/huh/nlp/smart_home/script/emdbedding/cattree.json'):
    """Map every question in the catalogue JSON to the ``asin`` of its product.

    The file is read as ``{"samples": [{"asin": ..., "QAS": [{"question": ...}, ...]}, ...]}``.
    Keys are normalised the same way remove_douhao normalises sentences: ASCII
    commas are removed and a single leading space is prepended, so the two can be
    matched. If a question appears under several products, the last one wins.

    :param json_path: path of the JSON file (UTF-8); defaults to the original
        hard-coded location.
    :return: dict mapping normalised question -> asin.
    """
    # JSON text is UTF-8 by specification; don't depend on the platform's default encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    sample_to_id = {}
    for sample in json_data['samples']:
        asin = sample['asin']
        for qa in sample['QAS']:
            sample_to_id[' ' + qa['question'].replace(',', '')] = asin
    return sample_to_id
def remove_douhao(contrast_list):
    """Drop every ASCII comma ("douhao") from each sentence and prefix it with one space.

    The leading space mirrors the key format built by get_sample_to_id.

    :param contrast_list: iterable of sentences.
    :return: new list of cleaned sentences, in input order.
    """
    return [' ' + sentence.replace(',', '') for sentence in contrast_list]
if __name__ == '__main__':
    # Ad-hoc demo inputs for get_sentence_similitay (the call is left disabled).
    ori_sentence = 'Has anyone found a good way to remove stains on this cat tree? '
    demo_contrast_list = ['Will the cat tower basket fit an adult cat of 18 lbs? ','How big of a cat will this structure hold? One of my cats weight 21 pounds. ','Can these shelves be installed onto concrete walls?','Does the chestnut/natural set in nclude everything in the picture including the planters??']
    # get_sentence_similitay(ori_sentence, demo_contrast_list)

    # Load the questions and strip commas. Then drop duplicates while keeping file
    # order: list(set(...)) gave a different, hash-seed-dependent row order on every
    # run, and deduplicating after comma removal also catches sentences that only
    # differed by commas.
    contrast_list = list(dict.fromkeys(remove_douhao(get_list())))
    en_to_ch = get_en_ch(contrast_list)
    sample_to_id = get_sample_to_id()
    get_all_samples(contrast_list, en_to_ch, sample_to_id)
边栏推荐
猜你喜欢
52. [Bool type input any non-0 value is not 1 version reason]
Pycharm (1) the basic use of tutorial
shell脚本
How to use postman
openpyxl 单元格合并
pnpm: Introduction
Redisson的看门狗机制
Biotin-EDA|CAS:111790-37-5| Ethylenediamine biotin
EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network
mysql 中 in 的用法
随机推荐
Flink 监控指南 被动拉取 Rest API
day_05模块
Ansible 学习总结(11)—— task 并行执行之 forks 与 serial 参数详解
为什么都推荐使用wordpress, 而不是 phpcms 这些国内的CMS呢?
RetinaFace: Single-stage Dense Face Localisation in the Wild
Redisson实现分布式锁
unity pdg 设置隐藏不需要的节点以及实现自动勾选自动加载项
MySQL 中 count() 和 count(1) 有什么区别?哪个性能最好?
How Engineers Treat Open Source --- A veteran engineer's heartfelt words
C语言_指针
pnpm:简介
构建Flink第一个应用程序
UVM事务级建模
R语言plotly可视化:plotly可视化回归模型实际值和回归预测值的散点图分析回归模型的预测效能、一个好的模型大部分的散点在对角线附近(predicted vs actual)
下一个排列
那些年我们踩过的 Flink 坑系列
科技云报道:实现元宇宙,英伟达从打造基础建设平台开始
Technology Cloud Report: To realize the metaverse, NVIDIA starts from building an infrastructure platform
Jenkins--基础--6.3--Pipeline--语法--脚本式
Biotin-C6-amine|N-biotinyl-1,6-hexanediamine|CAS: 65953-56-2