■
式 Y = WX + B（W: 重み行列、X: 入力ベクトル、B: バイアス）の意味については、次の資料（51/187）を参照。
この式を高速に計算するため、C++ AMP で行列演算を実装してみた（この分野では CUDA が業界標準なので、C++ AMP が今後どうなるかは少し気になるけれど）。
#include "stdafx.h"
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <memory>
#include <amp.h>
using namespace concurrency;
/// Dense float vector whose element buffer is held by a std::shared_ptr.
///
/// NOTE: copying a Vector copies the shared_ptr, not the elements, so the
/// copy aliases the same buffer (a write through one is visible through the
/// other). Call set() to obtain a fresh, independent buffer.
class Vector
{
friend class Matrix;
public :
    /// Creates an empty vector: count() == 0 and get() == nullptr.
    Vector() : cnt_(0) {}

    /// Creates a vector of `cnt` elements copied from `v`,
    /// or zero-filled when `v` is nullptr.
    Vector(const float* v, int cnt) : cnt_(0) { set(v, cnt); }

    /// Replaces the contents with `cnt` elements copied from `v`
    /// (zero-filled when `v` is nullptr).
    ///
    /// The new buffer is allocated and filled before the old one is released,
    /// so `x.set(x.get(), x.count())` is safe and a failed allocation leaves
    /// the vector unchanged.
    void set(const float* v, int cnt)
    {
        assert(cnt >= 0);
        // Array form of new needs the array deleter, not plain delete.
        std::shared_ptr<float> buf(new float[cnt], std::default_delete<float[]>());
        if (v == nullptr)
            std::fill_n(buf.get(), cnt, 0.0f);
        else
            std::copy_n(v, cnt, buf.get());
        val_.swap(buf);   // commit only after the copy succeeded
        cnt_ = cnt;
    }

    /// Element-wise addition; both vectors must have the same count().
    /// Self-addition (v += v) is well defined and doubles every element.
    Vector& operator+=(const Vector& v)
    {
        assert(cnt_ == v.cnt_);
        float* dst = val_.get();
        const float* src = v.val_.get();
        for (int i = 0; i < cnt_; i++)
            dst[i] += src[i];
        return *this;
    }

    /// Raw pointer to the first element (nullptr for an empty vector).
    float* get() { return val_.get(); }
    const float* get() const { return val_.get(); }

    /// Number of elements.
    int count() const { return cnt_; }

    /// Element access; `idx` must be in [0, count()).
    float& operator[](int idx)
    {
        assert(idx >= 0 && idx < cnt_);
        return val_.get()[idx];
    }
    const float& operator[](int idx) const
    {
        assert(idx >= 0 && idx < cnt_);
        return val_.get()[idx];
    }

private :
    std::shared_ptr<float> val_;
    int cnt_;
};
/// Dense row-major float matrix W (row_ x col_) whose products W * x and
/// W * x + b are evaluated with C++ AMP.
///
/// NOTE: copies share the coefficient buffer through std::shared_ptr.
class Matrix
{
friend class Vector;
public :
    /// Creates an empty 0 x 0 matrix; call init() before use.
    Matrix() : col_(0), row_(0) {}

    /// Creates a zero-filled `row` x `col` matrix.
    Matrix(int row, int col) : col_(0), row_(0) { init(row, col); }

    /// (Re)allocates storage for a `row` x `col` matrix and zero-fills it.
    void init(int row, int col)
    {
        assert(row >= 0 && col >= 0);
        w_ = std::shared_ptr<float>(new float[row * col], std::default_delete<float[]>());
        row_ = row;
        col_ = col;
        Zero();
    }

    /// Sets every coefficient to 0 (a no-op on an empty matrix).
    void Zero()
    {
        std::fill_n(w_.get(), row_ * col_, 0.0f);
    }

    /// Returns W * x as a Vector of row_ elements.
    /// `x` must have col_ elements.
    Vector Multi(Vector& x)
    {
        return Affine(x, nullptr);
    }

    /// Returns W * x + b as a Vector of row_ elements.
    /// `x` must have col_ elements and `b` must have row_ elements.
    Vector WXplusB(Vector& x, Vector& b)
    {
        assert(b.count() == row_);
        return Affine(x, b.get());
    }

    /// Raw pointer to the row-major coefficients.
    float* get() { return w_.get(); }

    /// Coefficient at (`row`, `col`).
    float& w(int row, int col)
    {
        assert(0 <= row && row < row_ && 0 <= col && col < col_);
        return w_.get()[row * col_ + col];
    }

private :
    /// Computes W * x + init on the default accelerator, where `init` points
    /// to row_ floats (or is nullptr, meaning a zero vector).
    ///
    /// One GPU thread is launched per output row. The previous kernels launched
    /// one thread per (row, col) pair and indexed (n x 1) views with col > 0,
    /// which read and wrote out of bounds and raced on the result.
    Vector Affine(Vector& x, float* init)
    {
        assert(row_ > 0 && col_ > 0);
        assert(x.count() == col_);

        // The result starts out as the bias (or zeros); the kernel adds W * x.
        Vector r(init, row_);

        concurrency::array_view<const float, 2> ww(row_, col_, w_.get());
        concurrency::array_view<const float, 1> xx(col_, x.get());
        concurrency::array_view<float, 1> rr(row_, r.get());
        // Copied into a local so the restrict(amp) lambda does not capture `this`.
        const int cols = col_;

        concurrency::parallel_for_each(
            rr.extent,
            [=](concurrency::index<1> idx) restrict(amp)
            {
                const int row = idx[0];
                float sum = 0.0f;
                for (int i = 0; i < cols; i++)
                    sum += ww(row, i) * xx(i);
                rr[idx] += sum;
            }
        );

        // Copy the written result back into r's host buffer before returning.
        rr.synchronize();
        return r;
    }

    std::shared_ptr<float> w_;
    int col_, row_;
};
int main()
{
// 3行3列 重み行列
Matrix m(3,3);
m.w(0, 0) = 2;
m.w(0, 1) = -3;
m.w(0, 2) = 4;m.w(1, 0) = -4;
m.w(1, 1) = 1;
m.w(1, 2) = -5;m.w(2, 0) = -4;
m.w(2, 1) = -2;
m.w(2, 2) = 5;// input
float v1[] = {1,0,1};
Vector in( v1, _countof(v1) );// bias
float b1 [] = { -4,5,2 };
Vector b(b1, _countof(b1));// Calc with C++ amp.
Vector Y = m.WXplusB( in, b );float a1 = Y[0]; // 2
float a2 = Y[1]; //-4
float a3 = Y[2]; //3
return 0;
}