Logistic Regression Principles
2022-07-27 23:01:00 【樱花的浪漫】
1. The sigmoid function
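Formula: $g(z) = \dfrac{1}{1 + e^{-z}}$
The input z can be any real number. The output lies in the open interval (0, 1): g(z) approaches 0 as z goes to $-\infty$ and approaches 1 as z goes to $+\infty$, but it never reaches either bound.
Interpretation: the sigmoid maps any input into (0, 1). Linear regression produces a real-valued prediction, and passing that value through the sigmoid converts it into a probability. This value-to-probability conversion is what turns the regression output into a classification task.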
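A minimal NumPy sketch of the sigmoid shows the squashing behaviour described above. The sample scores are arbitrary illustrative values. For very large negative inputs (around -1000), `np.exp` overflows and NumPy emits a warning, although the result still rounds to 0.

```python
import numpy as np

def sigmoid(z):
    """Logistic function g(z) = 1 / (1 + e^{-z}), applied element-wise."""
    return 1.0 / (1.0 + np.exp(-z))

# Linear-model scores anywhere on the real line ...
scores = np.array([-10.0, -1.0, 0.0, 1.0, 10.0])
# ... are squashed into probabilities strictly between 0 and 1.
probs = sigmoid(scores)
# probs[2] == 0.5 because g(0) = 1/2
print(probs)
```

The function is monotonically increasing and symmetric: g(-z) = 1 - g(z).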
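Prediction function: $h_\theta(x) = g(\theta^T x) = \dfrac{1}{1 + e^{-\theta^T x}}$
where $\theta_0 + \theta_1 x_1 + \dots + \theta_n x_n = \sum_{j=0}^{n} \theta_j x_j = \theta^T x$, with $x_0 = 1$ so that $\theta_0$ acts as the intercept.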
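The prediction function can be sketched for a small batch of samples as follows. Stacking samples as rows of X lets a single matrix product compute $\theta^T x$ for every sample at once. In this sketch the column of ones for $x_0$ is added by hand; the `prepare_for_training` helper in the code section may or may not do this, since its source is not shown. The parameter values are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def hypothesis(X, theta):
    """h_theta(x) = g(theta^T x) for every row of X (shape m x (n+1))."""
    return sigmoid(X @ theta)

# Two samples with one raw feature each; a leading column of ones
# plays the role of x_0 = 1 so that theta_0 acts as the intercept.
raw = np.array([[2.0], [-1.0]])
X = np.hstack([np.ones((raw.shape[0], 1)), raw])
theta = np.array([0.5, 1.5])   # hypothetical parameters [theta_0, theta_1]
p = hypothesis(X, theta)       # probabilities that y = 1 for each sample
```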
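Classification task: $P(y = 1 \mid x;\theta) = h_\theta(x)$, $\quad P(y = 0 \mid x;\theta) = 1 - h_\theta(x)$
Combined: $P(y \mid x;\theta) = \left(h_\theta(x)\right)^{y}\left(1 - h_\theta(x)\right)^{1 - y}$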
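Explanation: for binary classification with $y \in \{0, 1\}$, the combined form keeps only $1 - h_\theta(x)$ when y = 0 and only $h_\theta(x)$ when y = 1.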
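The exponents in the combined form act as switches, and a quick check with plain Python confirms this. The model output h = 0.8 is a hypothetical value.

```python
def bernoulli_prob(h, y):
    """P(y | x; theta) = h^y * (1 - h)^(1 - y) for y in {0, 1}."""
    return h ** y * (1.0 - h) ** (1 - y)

h = 0.8  # hypothetical model output h_theta(x)
# y = 1 keeps only the first factor (h), y = 0 keeps only the second (1 - h).
p_pos = bernoulli_prob(h, 1)
p_neg = bernoulli_prob(h, 0)
```

Because any number raised to the power 0 equals 1, one formula covers both labels. This is what lets the likelihood below be written as a single product.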

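Likelihood function: $L(\theta) = \prod_{i=1}^{m} P(y_i \mid x_i;\theta) = \prod_{i=1}^{m} \left(h_\theta(x_i)\right)^{y_i}\left(1 - h_\theta(x_i)\right)^{1 - y_i}$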
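Log-likelihood: $l(\theta) = \log L(\theta) = \sum_{i=1}^{m}\left(y_i \log h_\theta(x_i) + (1 - y_i)\log\left(1 - h_\theta(x_i)\right)\right)$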
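Maximizing $l(\theta)$ directly would call for gradient ascent. Introducing $J(\theta) = -\frac{1}{m} l(\theta)$ turns the problem into minimizing $J(\theta)$, i.e. a gradient descent task.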
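A vectorized sketch of $J(\theta)$ on toy data follows. The data and parameter vectors are made up for illustration. The `np.clip` guard against `log(0)` is an addition not present in the formula.

```python
import numpy as np

def cost(X, y, theta, eps=1e-12):
    """J(theta) = -(1/m) * sum(y*log(h) + (1-y)*log(1-h))."""
    h = 1.0 / (1.0 + np.exp(-(X @ theta)))
    h = np.clip(h, eps, 1 - eps)  # guard against log(0) for saturated predictions
    m = X.shape[0]
    return -(y @ np.log(h) + (1 - y) @ np.log(1 - h)) / m

X = np.array([[1.0, -2.0], [1.0, -1.0], [1.0, 1.0], [1.0, 2.0]])  # first column is x_0 = 1
y = np.array([0.0, 0.0, 1.0, 1.0])
J0 = cost(X, y, np.zeros(2))            # theta = 0 gives h = 0.5 everywhere, so J = log 2
J1 = cost(X, y, np.array([0.0, 1.0]))   # a theta that agrees with the labels scores lower
```

Minimizing J is equivalent to maximizing the log-likelihood. The factor 1/m only rescales the objective so that its magnitude does not grow with the dataset size.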
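Derivation. Here $x_i^j$ denotes the j-th feature of the i-th sample. Because $g'(z) = g(z)\left(1 - g(z)\right)$, we have $\frac{\partial h_\theta(x_i)}{\partial \theta_j} = h_\theta(x_i)\left(1 - h_\theta(x_i)\right)x_i^j$, and therefore:
$\frac{\partial J(\theta)}{\partial \theta_j} = -\frac{1}{m}\sum_{i=1}^{m}\left(\frac{y_i}{h_\theta(x_i)} - \frac{1 - y_i}{1 - h_\theta(x_i)}\right)\frac{\partial h_\theta(x_i)}{\partial \theta_j}$
$= -\frac{1}{m}\sum_{i=1}^{m}\left(y_i\left(1 - h_\theta(x_i)\right) - (1 - y_i)\,h_\theta(x_i)\right)x_i^j$
$= \frac{1}{m}\sum_{i=1}^{m}\left(h_\theta(x_i) - y_i\right)x_i^j$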
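The closed-form gradient can be checked numerically. The sketch compares the result of the derivation with central finite differences of $J(\theta)$. The data, parameters, and step size `eps` are arbitrary choices for the check.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def cost(X, y, theta):
    h = sigmoid(X @ theta)
    return -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))

def analytic_grad(X, y, theta):
    """dJ/dtheta_j = (1/m) * sum_i (h_theta(x_i) - y_i) * x_i^j."""
    return X.T @ (sigmoid(X @ theta) - y) / X.shape[0]

def numeric_grad(X, y, theta, eps=1e-6):
    """Central finite differences, one coordinate at a time."""
    grad = np.zeros_like(theta)
    for j in range(theta.size):
        step = np.zeros_like(theta)
        step[j] = eps
        grad[j] = (cost(X, y, theta + step) - cost(X, y, theta - step)) / (2 * eps)
    return grad

X = np.array([[1.0, 0.5, -1.2], [1.0, -0.3, 0.8], [1.0, 2.0, 0.1], [1.0, -1.5, -0.7]])
y = np.array([1.0, 0.0, 1.0, 0.0])
theta = np.array([0.1, -0.2, 0.3])
g_a = analytic_grad(X, y, theta)
g_n = numeric_grad(X, y, theta)
```

If `g_a` and `g_n` agree to about seven decimal places, the derivation (and any implementation of it) is very likely correct.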

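Parameter update (α is the learning rate; all $\theta_j$ are updated simultaneously): $\theta_j := \theta_j - \alpha\frac{1}{m}\sum_{i=1}^{m}\left(h_\theta(x_i) - y_i\right)x_i^j$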
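A minimal batch gradient descent loop implements the update rule directly. The learning rate `alpha=0.1`, the iteration count, and the toy data are illustrative assumptions. The labels deliberately overlap near 0: on perfectly separable data the weights of unregularized logistic regression keep growing instead of converging.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gradient_descent(X, y, alpha=0.1, n_iterations=1000):
    """theta_j := theta_j - alpha * (1/m) * sum_i (h(x_i) - y_i) * x_i^j, repeated n_iterations times."""
    m, n = X.shape
    theta = np.zeros(n)
    for _ in range(n_iterations):
        grad = X.T @ (sigmoid(X @ theta) - y) / m
        theta -= alpha * grad   # vectorized: every theta_j is updated simultaneously
    return theta

# Toy data: bias column plus one feature; labels overlap near 0 so the optimum is finite.
X = np.array([[1.0, -3.0], [1.0, -2.0], [1.0, -1.0], [1.0, -0.5],
              [1.0, 0.5], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
y = np.array([0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0])
theta = gradient_descent(X, y)
pred = (sigmoid(X @ theta) >= 0.5).astype(float)   # threshold probabilities at 0.5
```

The data are symmetric around x = 0, so the learned decision boundary sits at 0. The two overlapping points at ±0.5 are the only misclassifications.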
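Softmax for multi-class classification with K classes: $P(y = k \mid x;\theta) = \dfrac{e^{\theta_k^T x}}{\sum_{j=1}^{K} e^{\theta_j^T x}}, \quad k = 1, \dots, K$. (The implementation below handles multiple classes with one-vs-rest binary classifiers instead: it trains one sigmoid model per class and picks the class with the highest score.)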
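A numerically stable softmax sketch is shown below. Subtracting each row's maximum before exponentiating leaves the probabilities unchanged but prevents overflow. The scores stand in for hypothetical values of $\theta_k^T x$.

```python
import numpy as np

def softmax(scores):
    """Row-wise softmax: P(y = k | x) = exp(s_k) / sum_j exp(s_j), with s_k = theta_k^T x."""
    shifted = scores - scores.max(axis=1, keepdims=True)  # subtracting the row max avoids overflow
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)

# Hypothetical scores theta_k^T x for 2 samples and 3 classes.
scores = np.array([[2.0, 1.0, 0.1],
                   [1000.0, 1000.0, 1000.0]])  # a naive exp(1000) would overflow
P = softmax(scores)
```

With two classes and the second score fixed at 0, softmax reduces to the sigmoid: $e^{s}/(e^{s} + 1) = g(s)$.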
2. Code implementation
import numpy as np
from scipy.optimize import minimize
# prepare_for_training and sigmoid come from the project's own utils package (source not shown here)
from utils.features import prepare_for_training
from utils.hypothesis import sigmoid


class LogisticRegression:
    def __init__(self, data, labels, polynomial_degree=0, sinusoid_degree=0, normalize_data=False):
        # Data preprocessing: optional polynomial/sinusoid feature expansion and normalization
        (data_processed,
         features_mean,
         features_deviation) = prepare_for_training(data, polynomial_degree, sinusoid_degree,
                                                    normalize_data)
        self.data = data_processed
        # labels should be a column vector of shape (num_examples, 1);
        # a 1-D array would broadcast incorrectly in gradient_step
        self.labels = labels
        self.unique_labels = np.unique(labels)
        self.features_mean = features_mean
        self.features_deviation = features_deviation
        self.polynomial_degree = polynomial_degree
        self.sinusoid_degree = sinusoid_degree
        self.normalize_data = normalize_data

        # One parameter row per class (one-vs-rest), initialized to zero
        num_unique_labels = len(self.unique_labels)
        num_features = self.data.shape[1]
        self.theta = np.zeros((num_unique_labels, num_features))

    def train(self, n_iterations=500):
        cost_histories = []
        num_features = self.data.shape[1]
        for label_index, unique_label in enumerate(self.unique_labels):
            current_initial_theta = np.copy(self.theta[label_index].reshape(num_features, 1))
            # Binary targets: 1 for the current class, 0 for every other class
            current_labels = (self.labels == unique_label).astype(float)
            (theta, cost_history) = LogisticRegression.gradient_descent(self.data, current_initial_theta,
                                                                        current_labels, n_iterations)
            self.theta[label_index] = theta.ravel()
            cost_histories.append(cost_history)
        return self.theta, cost_histories

    @staticmethod
    def gradient_descent(data, current_initial_theta, current_labels, n_iterations):
        cost_history = []
        num_features = data.shape[1]
        result = minimize(
            # Objective to minimize
            lambda x: LogisticRegression.cost_function(data, current_labels, x.reshape(num_features, 1)),
            # Initial weights
            current_initial_theta,
            # Optimization method: conjugate gradient
            method='CG',
            # Gradient of the cost function with respect to theta
            jac=lambda x: LogisticRegression.gradient_step(data, current_labels, x.reshape(num_features, 1)),
            # Record the cost after each iteration
            callback=lambda x: cost_history.append(
                LogisticRegression.cost_function(data, current_labels, x.reshape(num_features, 1))),
            options={
                "maxiter": n_iterations
            }
        )
        if not result.success:
            raise ArithmeticError('Can not minimize cost function: ' + result.message)
        theta = result.x.reshape(num_features, 1)
        return theta, cost_history

    @staticmethod
    def cost_function(data, label, theta):
        num_examples = data.shape[0]
        prediction = LogisticRegression.hypothesis(data, theta)
        # Positive samples contribute log(h), negative samples contribute log(1 - h)
        y_true_cost = np.dot(label[label == 1].T, np.log(prediction[label == 1]))
        y_false_cost = np.dot(1 - label[label == 0].T, np.log(1 - prediction[label == 0]))
        cost = (-1 / num_examples) * (y_false_cost + y_true_cost)
        return cost

    @staticmethod
    def hypothesis(data, theta):
        # h_theta(x) = sigmoid(theta^T x) for every row of data
        predictions = sigmoid(np.dot(data, theta))
        return predictions

    @staticmethod
    def gradient_step(data, label, theta):
        num_examples = data.shape[0]
        prediction = LogisticRegression.hypothesis(data, theta)
        label_diff = prediction - label
        # (1/m) * X^T (h - y), flattened because minimize expects a 1-D gradient
        gradients = (1 / num_examples) * np.dot(data.T, label_diff)
        return gradients.T.flatten()

    def predict(self, data):
        num_examples = data.shape[0]
        # Apply the same feature transformation that was used during training
        data_processed = prepare_for_training(data, self.polynomial_degree, self.sinusoid_degree,
                                              self.normalize_data)[0]
        # One column of probabilities per class; pick the most confident one-vs-rest model
        prediction = LogisticRegression.hypothesis(data_processed, self.theta.T)
        arg = np.argmax(prediction, axis=1)
        class_prediction = np.empty(arg.shape, dtype=object)
        for index, unique_label in enumerate(self.unique_labels):
            class_prediction[arg == index] = unique_label
        return class_prediction.reshape((num_examples, 1))
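The class above depends on `utils.features.prepare_for_training` and `utils.hypothesis.sigmoid`, whose source is not shown, so it cannot run on its own. The standalone sketch below reproduces the same `scipy.optimize.minimize` pattern (`method="CG"`, an analytic `jac`, a cost-recording `callback`, `maxiter`) for a single binary classifier on hypothetical toy data. It uses 1-D parameter vectors instead of the column vectors used in the class.

```python
import numpy as np
from scipy.optimize import minimize

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def cost(theta, X, y):
    """Cross-entropy cost J(theta)."""
    h = sigmoid(X @ theta)
    return -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))

def gradient(theta, X, y):
    """(1/m) * X^T (h - y), returned as a 1-D array as minimize expects."""
    return X.T @ (sigmoid(X @ theta) - y) / X.shape[0]

# Hypothetical toy data: bias column + one feature, overlapping classes.
X = np.array([[1.0, -3.0], [1.0, -2.0], [1.0, -1.0], [1.0, -0.5],
              [1.0, 0.5], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
y = np.array([0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0])

cost_history = []
result = minimize(
    cost,
    x0=np.zeros(X.shape[1]),
    args=(X, y),
    method="CG",
    jac=gradient,
    callback=lambda theta: cost_history.append(cost(theta, X, y)),
    options={"maxiter": 500},
)
theta = result.x
pred = (sigmoid(X @ theta) >= 0.5).astype(float)
```

Passing data through `args` avoids the lambdas used in the class; both styles are accepted by `minimize`. For one-vs-rest, this fit would be repeated once per class with labels `(y == k)`, and prediction takes the argmax over the per-class probabilities, as `predict` does.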