Comparing sentence similarity with MiniLM
2022-08-02 08:28:00 【这个利弗莫尔不太冷】
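This post uses the all-MiniLM-L6-v2 model from sentence-transformers to embed sentences and rank a list of candidate questions by cosine similarity against a query. Before the full script, here is a minimal standalone sketch of the core calls it builds on (the two example sentences are made up for illustration):

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Encode two sentences and compare them with cosine similarity.
emb_1 = model.encode('How big a cat can this tree hold?', convert_to_tensor=True)
emb_2 = model.encode('Will this tower fit an 18 lb cat?', convert_to_tensor=True)
print(util.pytorch_cos_sim(emb_1, emb_2).item())  # a float in [-1, 1]

The full script follows.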
import json

from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

from translate_en_to_ch import get_en_to_zh_model

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
def get_sentence_similarity(ori_sentence, contrast_list):
    """Print the similarity between one sentence and each candidate."""
    # Encode the query once; only the candidate changes per iteration.
    embedding_1 = model.encode(ori_sentence, convert_to_tensor=True)
    print('Original sentence:')
    print(ori_sentence)
    for sample in contrast_list:
        embedding_2 = model.encode(sample, convert_to_tensor=True)
        similarity = util.pytorch_cos_sim(embedding_1, embedding_2).item()
        print(sample)
        print(similarity)
def get_list():
    """Load the candidate sentences, one per line."""
    question_txt = '/cloud/cloud_disk/users/huh/nlp/smart_home/script/emdbedding/huhao.txt'
    with open(question_txt, 'r') as fp:
        contents = fp.readlines()
    # Strip trailing newline/whitespace from each line.
    return [sample.strip() for sample in contents]
def get_top10(ori_sentence, contrast_list):
    """Print the ten candidates most similar to ori_sentence."""
    embedding_1 = model.encode(ori_sentence, convert_to_tensor=True)
    scored = []
    for sample in tqdm(contrast_list):
        embedding_2 = model.encode(sample, convert_to_tensor=True)
        similarity = abs(round(util.pytorch_cos_sim(embedding_1, embedding_2).item(), 8))
        # Keep (score, sentence) pairs; a dict keyed by the score would
        # silently drop sentences whose scores collide.
        scored.append((similarity, sample))
    scored.sort(reverse=True)
    for similarity, sample in scored[:10]:
        print('Similarity:')
        print(similarity)
        print('Matching sentence:')
        print(sample)
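Encoding the candidates one at a time dominates the runtime here. model.encode also accepts a list, and util.pytorch_cos_sim returns a full similarity matrix, so the scoring loop can collapse into a few batched calls. A sketch, assuming the same model and contrast_list as above:

query_emb = model.encode(ori_sentence, convert_to_tensor=True)
corpus_emb = model.encode(contrast_list, convert_to_tensor=True,
                          batch_size=64, show_progress_bar=True)
# One (1 x N) similarity matrix instead of N single-pair calls.
scores = util.pytorch_cos_sim(query_emb, corpus_emb)[0]
top = scores.topk(k=min(10, len(contrast_list)))
for score, idx in zip(top.values, top.indices):
    print(round(score.item(), 8), contrast_list[idx])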
def is_contain_chinese(check_str):
    """
    Check whether a string contains any Chinese character.
    :param check_str: {str} the string to check
    :return: {bool} True if it contains Chinese, otherwise False
    """
    for ch in check_str:
        if u'\u4e00' <= ch <= u'\u9fff':
            return True
    return False
def check_wrong_format():
    """Re-merge CSV rows whose sentence field itself contained commas."""
    with open('/cloud/cloud_disk/users/huh/nlp/smart_home/similitary.csv', 'r') as fp:
        contents = fp.readlines()
    all_samples_list = []
    for sample in contents:
        sample_tmp = sample.split(',')
        sentence = ''
        # Each block written by get_all_samples is 23 lines (2 header rows,
        # 10 result pairs, 1 blank line); line 0 of a block is the
        # original-sentence row, whose trailing fields are all sentence text.
        if len(all_samples_list) % 23 == 0:
            for index in range(1, len(sample_tmp)):
                sentence += sample_tmp[index]
            all_samples_list.append("{},{}".format(sample_tmp[0], sentence))
            continue
        try:
            # Chinese rows: label, sentence..., optional id
            if is_contain_chinese(sample_tmp[1]):
                if len(sample_tmp) > 3:
                    for index in range(1, len(sample_tmp) - 1):
                        sentence += sample_tmp[index]
                    all_samples_list.append("{},{},{}".format(sample_tmp[0], sentence, sample_tmp[-1]))
                else:
                    for index in range(1, len(sample_tmp)):
                        sentence += sample_tmp[index]
                    all_samples_list.append("{},{}".format(sample_tmp[0], sentence))
            # English rows: label, sentence..., score, id
            else:
                if len(sample_tmp) > 4:
                    for index in range(1, len(sample_tmp) - 2):
                        sentence += sample_tmp[index]
                    all_samples_list.append("{},{},{},{}".format(sample_tmp[0], sentence, sample_tmp[-2], sample_tmp[-1]))
                else:
                    all_samples_list.append(sample)
        except Exception:
            # Blank lines and short rows are copied through unchanged.
            all_samples_list.append(sample)
    # utf_8_sig writes a BOM so Excel detects UTF-8 and Chinese displays correctly.
    with open('s2.csv', 'w', encoding="utf_8_sig") as fp:
        fp.writelines(all_samples_list)
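The comma-splitting problem this function patches up only exists because the CSV was assembled by string concatenation. A minimal sketch of the standard-library alternative, which quotes fields containing commas up front (the file name and row values are illustrative):

import csv

# QUOTE_ALL wraps every field in quotes, so embedded commas survive a round trip.
with open('similarity_quoted.csv', 'w', newline='', encoding='utf_8_sig') as fp:
    writer = csv.writer(fp, quoting=csv.QUOTE_ALL)
    writer.writerow(['原句子:', 'Will it hold a big, heavy cat?', 0.73, 'B07XXXXX'])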
def solve_luanma():
    """Fix mojibake in Excel by rewriting the CSV with a UTF-8 BOM."""
    import pandas as pd
    file_name = '/cloud/cloud_disk/users/huh/nlp/smart_home/similitary.csv'
    df = pd.read_csv(file_name, encoding='utf-8')
    file_name3 = '/cloud/cloud_disk/users/huh/nlp/smart_home/similitary2.csv'
    df.to_csv(file_name3, encoding="utf_8_sig", sep=',')
def get_en_ch(contrast_list):
    """Build an English-to-Chinese dictionary for the candidates."""
    translation = get_en_to_zh_model()
    en_to_ch_dict = {}
    for sample in tqdm(contrast_list):
        ch = translation(sample, max_length=1024)[0]['translation_text']
        # Strip ASCII commas so the translation stays in one CSV field.
        en_to_ch_dict[sample] = ch.replace(',', '')
    return en_to_ch_dict
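translate_en_to_ch is a local helper module that is not shown in the post. Judging by the call signature and the [0]['translation_text'] access, it presumably wraps a Hugging Face translation pipeline. A hypothetical reconstruction (the Helsinki-NLP/opus-mt-en-zh model name is an assumption, not from the original):

# translate_en_to_ch.py -- hypothetical sketch; the original module is not shown.
from transformers import pipeline

def get_en_to_zh_model():
    # Any English-to-Chinese translation model would work here.
    return pipeline("translation", model="Helsinki-NLP/opus-mt-en-zh")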
def get_all_samples(contrast_list, en_to_ch, sample_to_id):
    """Write the ten most similar sentences for every candidate to a CSV."""
    # Encode each sentence once and reuse the embeddings for all pairs.
    sample_to_embedding = {}
    for sample in tqdm(contrast_list):
        sample_to_embedding[sample] = model.encode(sample, convert_to_tensor=True)
    xls_list = []
    for sample in tqdm(contrast_list):
        scored = []
        for tmp_sample in contrast_list:
            if tmp_sample == sample:
                continue  # skip self-similarity instead of relying on it sorting first
            similarity = abs(round(util.pytorch_cos_sim(
                sample_to_embedding[sample], sample_to_embedding[tmp_sample]).item(), 8))
            # (score, sentence) pairs avoid the collisions of a score-keyed dict.
            scored.append((similarity, tmp_sample))
        scored.sort(reverse=True)
        xls_list.append('原句子:,{},{},{}\n'.format(sample, ' ', sample_to_id[sample]))
        xls_list.append('原句子:,{}\n'.format(en_to_ch[sample]))
        for similarity, tmp_sample in scored[:10]:
            xls_list.append(" ,{},{},{}\n".format(tmp_sample, similarity, sample_to_id[tmp_sample]))
            xls_list.append(" ,{}\n".format(en_to_ch[tmp_sample]))
        xls_list.append("\n")
    with open('similitary.csv', 'w', encoding="utf_8_sig") as fp:
        fp.writelines(xls_list)
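The double loop above still makes O(n²) single-pair similarity calls. sentence-transformers ships a batched retrieval helper for exactly this all-pairs top-k pattern; a sketch of the same top-10 lookup with util.semantic_search (variable names are illustrative):

# Batch-encode once, then retrieve each sentence's nearest neighbours in one call.
embeddings = model.encode(contrast_list, convert_to_tensor=True)
# Each query's best hit is itself (score ~1.0), so ask for 11 and drop it.
hits = util.semantic_search(embeddings, embeddings, top_k=11)
for query_idx, neighbours in enumerate(hits):
    for hit in neighbours:
        if hit['corpus_id'] == query_idx:
            continue
        print(contrast_list[query_idx], '->',
              contrast_list[hit['corpus_id']], hit['score'])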
def get_sample_to_id():
    """Map every question (with commas removed) to its product asin."""
    sample_to_id = {}
    json_path = '/cloud/cloud_disk/users/huh/nlp/smart_home/script/emdbedding/cattree.json'
    with open(json_path, 'r') as f:
        json_data = json.load(f)
    for sample in json_data['samples']:
        asin = sample['asin']
        for tmp in sample['QAS']:
            question = tmp['question']
            # Apply the same comma removal as remove_douhao so keys match.
            sample_to_id[question.replace(',', '')] = asin
    return sample_to_id
def remove_douhao(contrast_list):
    """Remove commas so each sentence occupies a single CSV field."""
    return [sample.replace(',', '') for sample in contrast_list]
if __name__ == '__main__':
    ori_sentence = 'Has anyone found a good way to remove stains on this cat tree? '
    contrast_list = ['Will the cat tower basket fit an adult cat of 18 lbs? ',
                     'How big of a cat will this structure hold? One of my cats weight 21 pounds. ',
                     'Can these shelves be installed onto concrete walls?',
                     'Does the chestnut/natural set in nclude everything in the picture including the planters??']
    # get_sentence_similarity(ori_sentence, contrast_list)
    contrast_list = get_list()
    contrast_list = list(set(contrast_list))  # drop duplicate questions
    contrast_list = remove_douhao(contrast_list)
    en_to_ch = get_en_ch(contrast_list)
    sample_to_id = get_sample_to_id()
    get_all_samples(contrast_list, en_to_ch, sample_to_id)