Commit 12cc9408 authored by Evan Shelhamer

Merge pull request #2493 from longjon/sketchy-cuda-kernel-loop

Fix dangerous state in pooling and LRN CUDA kernels -- thanks @gustavla for the report in #2145
parents 9eb59787 5a6b0d68
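The "dangerous state" is the in-place pointer arithmetic the old kernels performed on their arguments (`in += offset;`, `bottom_data += ...;`, `top_diff += ...;`) inside CUDA_KERNEL_LOOP. That macro is a grid-stride loop, so a single thread may execute the body for several indices whenever nthreads exceeds the number of launched threads; every iteration after the first then starts from an already-shifted pointer and reads or writes the wrong memory. Below is a minimal sketch of the hazard and of the fix adopted here (the macro body matches Caffe's CUDA_KERNEL_LOOP; the toy scale_rows_* kernels and their names are illustrative only, not part of this commit):

// Grid-stride loop macro, same body as Caffe's CUDA_KERNEL_LOOP:
// each thread may run the body for more than one index.
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n); \
       i += blockDim.x * gridDim.x)

// Illustrative-only kernel showing the old, dangerous pattern.
template <typename Dtype>
__global__ void scale_rows_unsafe(const int nthreads, Dtype* data,
                                  const int row_len) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    const int offset = index * row_len;
    data += offset;  // DANGEROUS: mutates the argument, so on the next
                     // grid-stride iteration the new offset is applied on
                     // top of the previous one and the wrong row is touched.
    for (int i = 0; i < row_len; ++i) {
      data[i] *= Dtype(2);
    }
  }
}

// The pattern this commit switches to: keep the argument const and derive a
// fresh local pointer inside every loop iteration.
template <typename Dtype>
__global__ void scale_rows_safe(const int nthreads, Dtype* const data,
                                const int row_len) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    Dtype* const row = data + index * row_len;  // recomputed each iteration
    for (int i = 0; i < row_len; ++i) {
      row[i] *= Dtype(2);
    }
  }
}

The hazard only bites when a thread's loop body runs more than once (for example if the grid size is ever clamped), which is presumably why it went unnoticed until #2145. The hunks below remove every such pointer bump from the LRN and pooling kernels, const-qualify the pointer arguments so the compiler rejects the pattern, and introduce *_off / *_slice locals instead.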
@@ -7,44 +7,46 @@
namespace caffe {
template <typename Dtype>
__global__ void LRNFillScale(const int nthreads, const Dtype* in,
__global__ void LRNFillScale(const int nthreads, const Dtype* const in,
const int num, const int channels, const int height,
const int width, const int size, const Dtype alpha_over_size,
const Dtype k, Dtype* scale) {
const Dtype k, Dtype* const scale) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local offset
int w = index % width;
int h = (index / width) % height;
int n = index / width / height;
int offset = (n * channels * height + h) * width + w;
int step = height * width;
in += offset;
scale += offset;
const int w = index % width;
const int h = (index / width) % height;
const int n = index / width / height;
const int offset = (n * channels * height + h) * width + w;
const int step = height * width;
const Dtype* const in_off = in + offset;
Dtype* const scale_off = scale + offset;
int head = 0;
int pre_pad = (size - 1) / 2;
int post_pad = size - pre_pad - 1;
const int pre_pad = (size - 1) / 2;
const int post_pad = size - pre_pad - 1;
Dtype accum_scale = 0;
// fill the scale at [n, :, h, w]
// accumulate values
while (head < post_pad && head < channels) {
accum_scale += in[head * step] * in[head * step];
accum_scale += in_off[head * step] * in_off[head * step];
++head;
}
// both add and subtract
while (head < channels) {
accum_scale += in[head * step] * in[head * step];
accum_scale += in_off[head * step] * in_off[head * step];
if (head - size >= 0) {
accum_scale -= in[(head - size) * step] * in[(head - size) * step];
accum_scale -= in_off[(head - size) * step]
* in_off[(head - size) * step];
}
scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
scale_off[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
++head;
}
// subtract only
while (head < channels + post_pad) {
if (head - size >= 0) {
accum_scale -= in[(head - size) * step] * in[(head - size) * step];
accum_scale -= in_off[(head - size) * step]
* in_off[(head - size) * step];
}
scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
scale_off[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
++head;
}
}
@@ -68,8 +70,8 @@ void LRNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
// TODO: check if it would be faster to just put it into the previous kernel.
template <typename Dtype>
__global__ void LRNComputeOutput(const int nthreads, const Dtype* in,
const Dtype* scale, const Dtype negative_beta, Dtype* out) {
__global__ void LRNComputeOutput(const int nthreads, const Dtype* const in,
const Dtype* const scale, const Dtype negative_beta, Dtype* const out) {
CUDA_KERNEL_LOOP(index, nthreads) {
out[index] = in[index] * pow(scale[index], negative_beta);
}
@@ -118,56 +120,58 @@ void LRNLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
}
template <typename Dtype>
__global__ void LRNComputeDiff(const int nthreads, const Dtype* bottom_data,
const Dtype* top_data, const Dtype* scale, const Dtype* top_diff,
__global__ void LRNComputeDiff(const int nthreads,
const Dtype* const bottom_data, const Dtype* const top_data,
const Dtype* const scale, const Dtype* const top_diff,
const int num, const int channels, const int height,
const int width, const int size, const Dtype negative_beta,
const Dtype cache_ratio,
Dtype* bottom_diff) {
const Dtype cache_ratio, Dtype* const bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local offset
int w = index % width;
int h = (index / width) % height;
int n = index / width / height;
int offset = (n * channels * height + h) * width + w;
int step = height * width;
bottom_data += offset;
top_data += offset;
scale += offset;
top_diff += offset;
bottom_diff += offset;
const int w = index % width;
const int h = (index / width) % height;
const int n = index / width / height;
const int offset = (n * channels * height + h) * width + w;
const int step = height * width;
const Dtype* const bottom_off = bottom_data + offset;
const Dtype* const top_off = top_data + offset;
const Dtype* const scale_off = scale + offset;
const Dtype* const top_diff_off = top_diff + offset;
Dtype* const bottom_diff_off = bottom_diff + offset;
int head = 0;
int pre_pad = size - (size + 1) / 2;
int post_pad = size - pre_pad - 1;
const int pre_pad = size - (size + 1) / 2;
const int post_pad = size - pre_pad - 1;
Dtype accum_ratio = 0;
// accumulate values
while (head < post_pad && head < channels) {
accum_ratio += top_diff[head * step] * top_data[head * step] /
scale[head * step];
accum_ratio += top_diff_off[head * step] * top_off[head * step] /
scale_off[head * step];
++head;
}
// both add and subtract
while (head < channels) {
accum_ratio += top_diff[head * step] * top_data[head * step] /
scale[head * step];
accum_ratio += top_diff_off[head * step] * top_off[head * step] /
scale_off[head * step];
if (head - size >= 0) {
accum_ratio -= top_diff[(head - size) * step] *
top_data[(head - size) * step] / scale[(head - size) * step];
accum_ratio -= top_diff_off[(head - size) * step] *
top_off[(head - size) * step] / scale_off[(head - size) * step];
}
bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
* pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
bottom_diff_off[(head - post_pad) * step] =
top_diff_off[(head - post_pad) * step]
* pow(scale_off[(head - post_pad) * step], negative_beta)
- cache_ratio * bottom_off[(head - post_pad) * step] * accum_ratio;
++head;
}
// subtract only
while (head < channels + post_pad) {
if (head - size >= 0) {
accum_ratio -= top_diff[(head - size) * step] *
top_data[(head - size) * step] / scale[(head - size) * step];
accum_ratio -= top_diff_off[(head - size) * step] *
top_off[(head - size) * step] / scale_off[(head - size) * step];
}
bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
* pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
bottom_diff_off[(head - post_pad) * step] =
top_diff_off[(head - post_pad) * step]
* pow(scale_off[(head - post_pad) * step], negative_beta)
- cache_ratio * bottom_off[(head - post_pad) * step] * accum_ratio;
++head;
}
}
@@ -9,31 +9,32 @@
namespace caffe {
template <typename Dtype>
__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w, Dtype* top_data,
int* mask, Dtype* top_mask) {
__global__ void MaxPoolForward(const int nthreads,
const Dtype* const bottom_data, const int num, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
Dtype* const top_data, int* mask, Dtype* top_mask) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const int pw = index % pooled_width;
const int ph = (index / pooled_width) % pooled_height;
const int c = (index / pooled_width / pooled_height) % channels;
const int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
int hend = min(hstart + kernel_h, height);
int wend = min(wstart + kernel_w, width);
const int hend = min(hstart + kernel_h, height);
const int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
Dtype maxval = -FLT_MAX;
int maxidx = -1;
bottom_data += (n * channels + c) * height * width;
const Dtype* const bottom_slice =
bottom_data + (n * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (bottom_data[h * width + w] > maxval) {
if (bottom_slice[h * width + w] > maxval) {
maxidx = h * width + w;
maxval = bottom_data[maxidx];
maxval = bottom_slice[maxidx];
}
}
}
@@ -47,30 +48,32 @@ __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
}
template <typename Dtype>
__global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w, Dtype* top_data) {
__global__ void AvePoolForward(const int nthreads,
const Dtype* const bottom_data, const int num, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
Dtype* const top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const int pw = index % pooled_width;
const int ph = (index / pooled_width) % pooled_height;
const int c = (index / pooled_width / pooled_height) % channels;
const int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
int hend = min(hstart + kernel_h, height + pad_h);
int wend = min(wstart + kernel_w, width + pad_w);
int pool_size = (hend - hstart) * (wend - wstart);
const int pool_size = (hend - hstart) * (wend - wstart);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
hend = min(hend, height);
wend = min(wend, width);
Dtype aveval = 0;
bottom_data += (n * channels + c) * height * width;
const Dtype* const bottom_slice =
bottom_data + (n * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
aveval += bottom_data[h * width + w];
aveval += bottom_slice[h * width + w];
}
}
top_data[index] = aveval / pool_size;
@@ -79,37 +82,38 @@ __global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
template <typename Dtype>
__global__ void StoPoolForwardTrain(const int nthreads,
const Dtype* bottom_data,
const Dtype* const bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, Dtype* rand_idx, Dtype* top_data) {
const int stride_w, Dtype* const rand_idx, Dtype* const top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h;
int hend = min(hstart + kernel_h, height);
int wstart = pw * stride_w;
int wend = min(wstart + kernel_w, width);
const int pw = index % pooled_width;
const int ph = (index / pooled_width) % pooled_height;
const int c = (index / pooled_width / pooled_height) % channels;
const int n = index / pooled_width / pooled_height / channels;
const int hstart = ph * stride_h;
const int hend = min(hstart + kernel_h, height);
const int wstart = pw * stride_w;
const int wend = min(wstart + kernel_w, width);
Dtype cumsum = 0.;
bottom_data += (n * channels + c) * height * width;
const Dtype* const bottom_slice =
bottom_data + (n * channels + c) * height * width;
// First pass: get sum
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
cumsum += bottom_data[h * width + w];
cumsum += bottom_slice[h * width + w];
}
}
float thres = rand_idx[index] * cumsum;
const float thres = rand_idx[index] * cumsum;
// Second pass: get value, and set index.
cumsum = 0;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
cumsum += bottom_data[h * width + w];
cumsum += bottom_slice[h * width + w];
if (cumsum >= thres) {
rand_idx[index] = ((n * channels + c) * height + h) * width + w;
top_data[index] = bottom_data[h * width + w];
top_data[index] = bottom_slice[h * width + w];
return;
}
}
@@ -120,29 +124,30 @@ __global__ void StoPoolForwardTrain(const int nthreads,
template <typename Dtype>
__global__ void StoPoolForwardTest(const int nthreads,
const Dtype* bottom_data,
const Dtype* const bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, Dtype* top_data) {
const int stride_w, Dtype* const top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h;
int hend = min(hstart + kernel_h, height);
int wstart = pw * stride_w;
int wend = min(wstart + kernel_w, width);
const int pw = index % pooled_width;
const int ph = (index / pooled_width) % pooled_height;
const int c = (index / pooled_width / pooled_height) % channels;
const int n = index / pooled_width / pooled_height / channels;
const int hstart = ph * stride_h;
const int hend = min(hstart + kernel_h, height);
const int wstart = pw * stride_w;
const int wend = min(wstart + kernel_w, width);
// We set cumsum to FLT_MIN rather than 0 to avoid divide-by-zero problems
Dtype cumsum = FLT_MIN;
Dtype cumvalues = 0.;
bottom_data += (n * channels + c) * height * width;
const Dtype* const bottom_slice =
bottom_data + (n * channels + c) * height * width;
// First pass: get sum
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
cumsum += bottom_data[h * width + w];
cumvalues += bottom_data[h * width + w] * bottom_data[h * width + w];
cumsum += bottom_slice[h * width + w];
cumvalues += bottom_slice[h * width + w] * bottom_slice[h * width + w];
}
}
top_data[index] = cumvalues / cumsum;
@@ -210,43 +215,43 @@ void PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
template <typename Dtype>
__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
const int* mask, const Dtype* top_mask, const int num, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
Dtype* bottom_diff) {
__global__ void MaxPoolBackward(const int nthreads, const Dtype* const top_diff,
const int* const mask, const Dtype* const top_mask, const int num,
const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, const int kernel_h,
const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
const int pad_w, Dtype* const bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
// find out the local offset
int w = index % width;
int h = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int phstart =
(h + pad_h < kernel_h) ? 0 : (h + pad_h - kernel_h) / stride_h + 1;
int phend = min((h + pad_h) / stride_h + 1, pooled_height);
int pwstart =
(w + pad_w < kernel_w) ? 0 : (w + pad_w - kernel_w) / stride_w + 1;
int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
const int w = index % width;
const int h = (index / width) % height;
const int c = (index / width / height) % channels;
const int n = index / width / height / channels;
const int phstart =
(h + pad_h < kernel_h) ? 0 : (h + pad_h - kernel_h) / stride_h + 1;
const int phend = min((h + pad_h) / stride_h + 1, pooled_height);
const int pwstart =
(w + pad_w < kernel_w) ? 0 : (w + pad_w - kernel_w) / stride_w + 1;
const int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
Dtype gradient = 0;
int offset = (n * channels + c) * pooled_height * pooled_width;
top_diff += offset;
const int offset = (n * channels + c) * pooled_height * pooled_width;
const Dtype* const top_diff_slice = top_diff + offset;
if (mask) {
mask += offset;
const int* const mask_slice = mask + offset;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if (mask[ph * pooled_width + pw] == h * width + w) {
gradient += top_diff[ph * pooled_width + pw];
if (mask_slice[ph * pooled_width + pw] == h * width + w) {
gradient += top_diff_slice[ph * pooled_width + pw];
}
}
}
} else {
top_mask += offset;
const Dtype* const top_mask_slice = top_mask + offset;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if (top_mask[ph * pooled_width + pw] == h * width + w) {
gradient += top_diff[ph * pooled_width + pw];
if (top_mask_slice[ph * pooled_width + pw] == h * width + w) {
gradient += top_diff_slice[ph * pooled_width + pw];
}
}
}
@@ -256,25 +261,26 @@ __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
}
template <typename Dtype>
__global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
__global__ void AvePoolBackward(const int nthreads, const Dtype* const top_diff,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w,
Dtype* bottom_diff) {
Dtype* const bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
// find out the local offset
int w = index % width + pad_w;
int h = (index / width) % height + pad_h;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
int phend = min(h / stride_h + 1, pooled_height);
int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
int pwend = min(w / stride_w + 1, pooled_width);
const int w = index % width + pad_w;
const int h = (index / width) % height + pad_h;
const int c = (index / width / height) % channels;
const int n = index / width / height / channels;
const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
const int phend = min(h / stride_h + 1, pooled_height);
const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
const int pwend = min(w / stride_w + 1, pooled_width);
Dtype gradient = 0;
top_diff += (n * channels + c) * pooled_height * pooled_width;
const Dtype* const top_diff_slice =
top_diff + (n * channels + c) * pooled_height * pooled_width;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
// figure out the pooling size
@@ -283,7 +289,7 @@ __global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
int hend = min(hstart + kernel_h, height + pad_h);
int wend = min(wstart + kernel_w, width + pad_w);
int pool_size = (hend - hstart) * (wend - wstart);
gradient += top_diff[ph * pooled_width + pw] / pool_size;
gradient += top_diff_slice[ph * pooled_width + pw] / pool_size;
}
}
bottom_diff[index] = gradient;
@@ -293,29 +299,31 @@ __global__ void StoPoolBackward(const int nthreads,
template <typename Dtype>
__global__ void StoPoolBackward(const int nthreads,
const Dtype* rand_idx, const Dtype* top_diff,
const Dtype* const rand_idx, const Dtype* const top_diff,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, Dtype* bottom_diff) {
const int stride_w, Dtype* const bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
// find out the local offset
int w = index % width;
int h = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
int phend = min(h / stride_h + 1, pooled_height);
int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
int pwend = min(w / stride_w + 1, pooled_width);
const int w = index % width;
const int h = (index / width) % height;
const int c = (index / width / height) % channels;
const int n = index / width / height / channels;
const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
const int phend = min(h / stride_h + 1, pooled_height);
const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
const int pwend = min(w / stride_w + 1, pooled_width);
Dtype gradient = 0;
rand_idx += (n * channels + c) * pooled_height * pooled_width;
top_diff += (n * channels + c) * pooled_height * pooled_width;
const Dtype* const rand_idx_slice =
rand_idx + (n * channels + c) * pooled_height * pooled_width;
const Dtype* const top_diff_slice =
top_diff + (n * channels + c) * pooled_height * pooled_width;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
gradient += top_diff[ph * pooled_width + pw] *
(index == static_cast<int>(rand_idx[ph * pooled_width + pw]));
gradient += top_diff_slice[ph * pooled_width + pw] *
(index == static_cast<int>(rand_idx_slice[ph * pooled_width + pw]));
}
}
bottom_diff[index] = gradient;