# 使用Python从头开始手写回归树

``````import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
``````

``````def f(x):
mu, sigma = 0, 1.5
return -x**2 + x + 5 + np.random.normal(mu, sigma, 1)

num_points = 300
# Seed the global NumPy RNG so the sampled data set is reproducible.
np.random.seed(1)

# All x values are drawn first, then f() draws one noise value per point;
# changing this order would change the generated data.
x = np.random.uniform(-2, 5, num_points)
# f returns a 1-element array per call, so y ends up with shape (num_points, 1).
y = np.array( [f(i) for i in x] )

plt.scatter(x, y, s = 5)
``````

## 回归树

``````threshold = 1.5

# The hard-coded threshold above splits the data in two; gather the targets on
# each side (points exactly equal to the threshold fall in neither group).
low = np.take(y, np.where(x < threshold))
high = np.take(y, np.where(x > threshold))

plt.scatter(x, y, s = 5, label = 'Data')
# Vertical marker at the threshold; the y-range [-16, 10] is hard-coded.
plt.plot([threshold]*2, [-16, 10], 'b--', label = 'Threshold line')
# Each side is predicted by the mean target of the samples on that side.
plt.plot([-2, threshold], [low.mean()]*2, 'r--', label = 'Left child prediction line')
plt.plot([threshold, 5], [high.mean()]*2, 'r--', label = 'Right child prediction line')
# Mean of all targets, i.e. the prediction without any split.
plt.plot([-2, 5], [y.mean()]*2, 'g--', label = 'Node prediction line')
plt.legend()

``````def SSR(r, y):
return np.sum( (r - y)**2 )

SSRs, thresholds = [], []
# Candidate splits are the midpoints between neighbouring x values, so the
# values must be walked in sorted order. A sorted copy is used; x itself is
# left untouched so it stays aligned with y.
x_sorted = np.sort(x)
for threshold in (x_sorted[:-1] + x_sorted[1:]) / 2:
    low = np.take(y, np.where(x < threshold))
    high = np.take(y, np.where(x > threshold))

    # Each side is predicted by its own mean target.
    guess_low = low.mean()
    guess_high = high.mean()

    SSRs.append(SSR(low, guess_low) + SSR(high, guess_high))
    thresholds.append(threshold)

# Index of the first threshold with the smallest total SSR.
best = int(np.argmin(SSRs))
print('Minimum residual is: {:.2f}'.format(SSRs[best]))
print('Corresponding threshold value is: {:.4f}'.format(thresholds[best]))
``````

``````df = pd.DataFrame(zip(x, y.squeeze()), columns = ['x', 'y'])
# y was built from 1-element arrays; squeeze() drops that extra axis so each
# row of the frame above pairs one x with one scalar y.

def find_threshold(df, plot = False):
    """Return the split threshold on ``df.x`` that minimises the total SSR.

    Candidate thresholds are the midpoints between consecutive distinct
    values of ``df.x`` taken in sorted order, so every split that separates
    the samples differently is examined regardless of the row order of ``df``.
    For each candidate the rows are divided into ``x <= threshold`` and
    ``x > threshold``, each side is predicted by its mean ``y``, and the two
    sums of squared residuals are added.

    Parameters:
        df: DataFrame with numeric columns ``x`` and ``y``.
        plot: when true, scatter-plot SSR against threshold via ``plt``.

    Returns:
        The first candidate threshold (in increasing order) with the smallest
        total SSR. If every ``x`` is identical, that single value is returned.

    Raises:
        ValueError: if ``df`` has fewer than two rows.
    """
    if len(df) < 2:
        raise ValueError('find_threshold needs at least two rows')

    def _ssr_about_mean(values):
        # Sum of squared deviations from the mean; an empty side adds nothing.
        if values.size == 0:
            return 0.0
        return np.sum((values - values.mean()) ** 2)

    xs = df.x.to_numpy()
    ys = df.y.to_numpy()

    distinct = np.unique(xs)  # sorted, duplicates removed
    if len(distinct) > 1:
        thresholds = (distinct[:-1] + distinct[1:]) / 2
    else:
        # All x are equal: nothing can be separated, keep the lone value.
        thresholds = distinct

    SSRs = []
    for threshold in thresholds:
        goes_low = xs <= threshold
        SSRs.append(_ssr_about_mean(ys[goes_low]) + _ssr_about_mean(ys[~goes_low]))

    if plot:
        plt.scatter(thresholds, SSRs, s = 3)
        plt.show()

    return thresholds[int(np.argmin(SSRs))]
``````

## 创建子节点

``````class TreeNode():
def __init__(self, threshold, pred):
self.threshold = threshold
self.pred = pred
self.left = None
self.right = None

def create_nodes(tree, df, stop):
    """Recursively grow the subtree below ``tree`` from the samples in ``df``.

    ``df`` is partitioned at ``tree.threshold``. A side holding more than
    ``stop`` rows gets a child node (threshold from ``find_threshold``,
    prediction equal to the side's mean ``y``) and is expanded recursively;
    a side with ``stop`` rows or fewer is left without a child.

    Parameters:
        tree: node to attach children to; modified in place.
        df: DataFrame with columns ``x`` and ``y`` holding this node's samples.
        stop: a side is split further only if it has more than this many rows.
    """
    low = df[df.x <= tree.threshold]
    high = df[df.x > tree.threshold]

    # A threshold that leaves one side empty separates nothing: the other side
    # is the same sample set again, so it would yield the same threshold and
    # the recursion would never terminate. Keep this node as a leaf instead.
    if len(low) == 0 or len(high) == 0:
        return

    if len(low) > stop:
        tree.left = TreeNode(find_threshold(low), low.y.mean())
        create_nodes(tree.left, low, stop)

    if len(high) > stop:
        tree.right = TreeNode(find_threshold(high), high.y.mean())
        create_nodes(tree.right, high, stop)

# Root node: best split over the whole data set, predicting the overall mean.
threshold = find_threshold(df)
tree = TreeNode(threshold = threshold, pred = df.y.mean())

# Grow the tree; a side is split further only while it holds more than 5 rows.
create_nodes(tree, df, stop = 5)
``````

``````plt.scatter(x, y, s = 0.5, label = 'Data')
# Vertical dashed lines mark thresholds; their y-range [-16, 10] is hard-coded.
plt.plot([tree.threshold]*2, [-16, 10], 'r--',
label = 'Root threshold')
plt.plot([tree.right.threshold]*2, [-16, 10], 'g--',
label = 'Right node threshold')
# Between the root threshold and the right child's threshold, the node reached
# is tree.right.left, so its prediction is drawn over that interval.
plt.plot([tree.threshold, tree.right.threshold],
[tree.right.left.pred]*2,
'g', label = 'Right node prediction')
plt.plot([tree.left.threshold]*2, [-16, 10], 'm--',
label = 'Left node threshold')
# Between the left child's threshold and the root threshold, the node reached
# is tree.left.right.
plt.plot([tree.left.threshold, tree.threshold],
[tree.left.right.pred]*2,
'm', label = 'Left node prediction')
plt.plot([tree.left.left.threshold]*2, [-16, 10], 'k--',
label = 'Second Left node threshold')
plt.legend()

``````tree.left.right.left.left
# Bare expression that walks four levels below the root (presumably evaluated
# in a notebook cell to display the node). Children default to None, so this
# raises AttributeError if any node along the path was never created.
``````

## 预测

``````def predict(x):
curr_node = tree
result = None
while True:
if x <= curr_node.threshold:
if curr_node.left: curr_node = curr_node.left
else:
break
elif x > curr_node.threshold:
if curr_node.right: curr_node = curr_node.right
else:
break

return curr_node.pred
``````

``````predict(3)
# -1.23741  (presumably the value the author observed; it depends on the fitted tree)
``````

## 计算误差

``````def RSE(y, g):
return sum(np.square(y - g)) / sum(np.square(y - 1 / len(y)*sum(y)))

# 50 fresh points from the same x-range as the training data. The RNG is not
# re-seeded here, so these values depend on every random draw made before.
x_val = np.random.uniform(-2, 5, 50)
# squeeze() flattens the 1-element arrays returned by f into a 1-D vector.
y_val = np.array( [f(i) for i in x_val] ).squeeze()

# Predict one value at a time with the fitted tree.
tr_preds = np.array( [predict(i) for i in df.x] )
val_preds = np.array( [predict(i) for i in x_val] )
print('Training error: {:.4f}'.format(RSE(df.y, tr_preds)))
print('Validation error: {:.4f}'.format(RSE(y_val, val_preds)))

``````

## 更深入的模型

``````def f(x):
mu, sigma = 0, 0.5
if x < 3: return 1 + np.random.normal(mu, sigma, 1)
elif x >= 3 and x < 6: return 9 + np.random.normal(mu, sigma, 1)
elif x >= 6: return 5 + np.random.normal(mu, sigma, 1)

# Re-seed so the step-function data set is reproducible.
np.random.seed(1)

# num_points is reused from the earlier data-generation cell.
x = np.random.uniform(0, 10, num_points)
# f returns a 1-element array per call, so y has shape (num_points, 1).
y = np.array( [f(i) for i in x] )

plt.scatter(x, y, s = 5)

``````

``````import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

#===================================================Create Data
def f(x):
    """Return a noisy sample of the parabola ``-x**2 + x + 5`` at ``x``.

    Gaussian noise with mean 0 and standard deviation 1.5 is added. The noise
    is drawn with ``size=1``, so for a scalar ``x`` the result is a 1-element
    NumPy array rather than a plain float.
    """
    mu, sigma = 0, 1.5
    return -x**2 + x + 5 + np.random.normal(mu, sigma, 1)

# Seed the global NumPy RNG so the data set is reproducible.
np.random.seed(1)

# All x values are drawn first, then f() draws one noise value per point.
x = np.random.uniform(-2, 5, 300)
y = np.array( [f(i) for i in x] )

# Sort both arrays by x with one shared permutation so each (x, y) pair stays
# together; the threshold scan below relies on neighbouring x values.
p = x.argsort()
x = x[p]
y = y[p]

#===================================================Calculate Thresholds
def SSR(r, y): #send numpy array
    """Return the sum of squared residuals between ``r`` and ``y``.

    The difference ``r - y`` is squared element-wise and summed over every
    element, so either argument may be an array or a scalar that broadcasts
    against the other.
    """
    return np.sum( (r - y)**2 )

SSRs, thresholds = [], []
# x was sorted just above, so each pair of neighbours x[i], x[i+1] yields one
# candidate threshold at their midpoint.
for i in range(len(x) - 1):
    threshold = x[i:i+2].mean()

    # Targets on each side of the candidate split.
    low = np.take(y, np.where(x < threshold))
    high = np.take(y, np.where(x > threshold))

    # Each side is predicted by its own mean target.
    guess_low = low.mean()
    guess_high = high.mean()

    SSRs.append(SSR(low, guess_low) + SSR(high, guess_high))
    thresholds.append(threshold)

#===================================================Animated Plot
# Two stacked panels sharing the x axis: data with the moving threshold on
# top, SSR per threshold below.
fig, (ax1, ax2) = plt.subplots(nrows = 2, ncols = 1, sharex = True)
# Points of the SSR curve revealed so far; update() appends one per frame.
x_data2, y_data2 = [], []
x_data, y_data = [], []
# plot() returns a list holding the single line it created.
ln = ax1.plot([], [], 'r--')[0]
ln2 = ax2.plot(thresholds, SSRs, 'ro', markersize = 2)[0]
line = [ln, ln2]

def init():
    """Draw the static parts of the figure and return the animated artists.

    Passed to ``FuncAnimation`` as ``init_func``. It scatters the data on the
    top axes, sets titles and axis labels on both axes, and returns the
    module-level ``line`` list.
    """
    ax1.scatter(x, y, s = 3)
    ax1.title.set_text('Trying Different Thresholds')
    ax2.title.set_text('Threshold vs SSR')
    ax1.set_ylabel('y values')
    ax2.set_xlabel('Threshold')
    ax2.set_ylabel('SSR')
    return line

def update(frame):
    """Advance the animation to candidate threshold number ``frame``.

    Moves the vertical marker on the top axes to ``thresholds[frame]`` and
    appends that threshold's SSR to the curve on the bottom axes.

    Parameters:
        frame: index into ``thresholds`` / ``SSRs``.

    Returns:
        The module-level ``line`` list.
    """
    # thresholds[frame] is the midpoint of x[frame] and x[frame + 1] computed
    # during the scan, so it is reused rather than recomputed.
    split = thresholds[frame]
    line[0].set_data([split, split], [y.min(), y.max()])

    x_data2.append(split)
    y_data2.append(SSRs[frame])
    line[1].set_data(x_data2, y_data2)
    return line

# One frame per candidate threshold. The previous hard-coded count of 298
# stopped one short of the 299 thresholds produced from 300 points.
ani = FuncAnimation(fig, update, frames = len(thresholds),
                    init_func = init, blit = True)
plt.show()
``````

https://avoid.overfit.cn/post/68d76a2540894366bb7033ff120a30d6