p0_starter.h
#pragma once
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "common/exception.h"
#include "common/logger.h"
namespace bustub {
/**
 * Abstract base for a rows x cols matrix backed by one flat, heap-allocated
 * element buffer (`linear_`) that this class owns and frees.
 *
 * Non-copyable: `linear_` is a raw owning pointer, so a member-wise copy would
 * make two objects delete[] the same buffer.
 */
template <typename T>
class Matrix {
 protected:
  /**
   * Allocate storage for rows * cols value-initialized elements.
   *
   * Non-positive dimensions put the matrix in an "invalid" state:
   * rows_ == cols_ == -1 and linear_ == nullptr.
   *
   * @param rows number of rows
   * @param cols number of columns
   */
  Matrix(int rows, int cols) : rows_(rows), cols_(cols), linear_(nullptr) {
    if (rows_ <= 0 || cols_ <= 0) {
      rows_ = cols_ = -1;
      return;
    }
    // Compute the element count in size_t so large dimensions cannot overflow int.
    const std::size_t count = static_cast<std::size_t>(rows) * static_cast<std::size_t>(cols);
    // `()` value-initializes, so unread elements are well-defined (e.g. 0 for arithmetic T).
    linear_ = new T[count]();
  }

  /** Number of rows, or -1 for an invalid matrix. */
  int rows_;
  /** Number of columns, or -1 for an invalid matrix. */
  int cols_;
  /** Flattened element buffer owned by this object; nullptr for an invalid matrix. */
  T *linear_;

 public:
  Matrix(const Matrix &) = delete;
  Matrix &operator=(const Matrix &) = delete;

  /** @return the number of rows. */
  virtual int GetRowCount() const = 0;
  /** @return the number of columns. */
  virtual int GetColumnCount() const = 0;
  /** @return the element at row i, column j. */
  virtual T GetElement(int i, int j) const = 0;
  /** Store val at row i, column j. */
  virtual void SetElement(int i, int j, T val) = 0;
  /** Overwrite every element from a flat source vector. */
  virtual void FillFrom(const std::vector<T> &source) = 0;

  virtual ~Matrix() { delete[] linear_; }
};
/**
 * Row-major matrix. It indexes the flat buffer inherited from Matrix<T>
 * through a table of per-row pointers (`data_`), so element (i, j) is data_[i][j].
 *
 * Non-copyable: the row-pointer table is a raw owning pointer, so a copy
 * would double-free it.
 */
template <typename T>
class RowMatrix : public Matrix<T> {
 public:
  /**
   * Construct a rows x cols matrix. Non-positive dimensions produce an invalid
   * matrix (row and column counts of -1, no storage).
   */
  RowMatrix(int rows, int cols) : Matrix<T>(rows, cols), data_(nullptr) {
    if (this->rows_ <= 0 || this->cols_ <= 0) {
      this->rows_ = -1;
      this->cols_ = -1;
      return;
    }
    // One pointer per row is enough; each points at the first element of its row.
    data_ = new T *[this->rows_];
    for (int i = 0; i < this->rows_; i++) {
      data_[i] = &this->linear_[static_cast<std::size_t>(i) * static_cast<std::size_t>(this->cols_)];
    }
  }

  RowMatrix(const RowMatrix &) = delete;
  RowMatrix &operator=(const RowMatrix &) = delete;

  int GetRowCount() const override { return this->rows_; }
  int GetColumnCount() const override { return this->cols_; }

  /**
   * @return the element at row i, column j
   * @throws Exception with ExceptionType::OUT_OF_RANGE if (i, j) is outside the matrix
   */
  T GetElement(int i, int j) const override {
    if (!InBounds(i, j)) {
      throw Exception(ExceptionType::OUT_OF_RANGE, "RowMatrix::GetElement(int,int) out of range:i=" +
                                                       std::to_string(i) + ",j=" + std::to_string(j));
    }
    return data_[i][j];
  }

  /**
   * Store val at row i, column j.
   * @throws Exception with ExceptionType::OUT_OF_RANGE if (i, j) is outside the matrix
   */
  void SetElement(int i, int j, T val) override {
    if (!InBounds(i, j)) {
      throw Exception(ExceptionType::OUT_OF_RANGE, "RowMatrix::SetElement(int,int) out of range:i=" +
                                                       std::to_string(i) + ",j=" + std::to_string(j));
    }
    data_[i][j] = val;
  }

  /**
   * Copy a row-major flat vector into the matrix.
   * @throws Exception with ExceptionType::OUT_OF_RANGE if source.size() != rows * cols
   */
  void FillFrom(const std::vector<T> &source) override {
    if (static_cast<int>(source.size()) != this->rows_ * this->cols_) {
      throw Exception(ExceptionType::OUT_OF_RANGE, "RowMatrix::FillFrom(vector) out of range");
    }
    for (int i = 0; i < this->rows_; i++) {
      for (int j = 0; j < this->cols_; j++) {
        // Indices are in range by construction, so skip SetElement's bounds check.
        data_[i][j] = source[i * this->cols_ + j];
      }
    }
  }

  ~RowMatrix() override { delete[] data_; }

 private:
  /** @return true iff (i, j) addresses an element of this matrix. */
  bool InBounds(int i, int j) const { return i >= 0 && i < this->rows_ && j >= 0 && j < this->cols_; }

  /** Per-row pointers into the inherited flat buffer; nullptr for an invalid matrix. */
  T **data_;
};
/**
 * Stateless arithmetic on RowMatrix operands. Every operation allocates and
 * returns a new matrix, or nullptr when an input is null or the dimensions
 * are incompatible.
 */
template <typename T>
class RowMatrixOperations {
 public:
  /**
   * Element-wise sum matrixA + matrixB.
   * @return a new matrix, or nullptr if either input is null or their dimensions differ
   */
  static std::unique_ptr<RowMatrix<T>> Add(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB) {
    // Validate pointers before dereferencing either operand.
    if (matrixA == nullptr || matrixB == nullptr) {
      return nullptr;
    }
    const int rows = matrixA->GetRowCount();
    const int cols = matrixA->GetColumnCount();
    if (rows != matrixB->GetRowCount() || cols != matrixB->GetColumnCount()) {
      return nullptr;
    }
    auto result = std::make_unique<RowMatrix<T>>(rows, cols);
    for (int i = 0; i < rows; i++) {
      for (int j = 0; j < cols; j++) {
        T sum = matrixA->GetElement(i, j) + matrixB->GetElement(i, j);
        result->SetElement(i, j, sum);
      }
    }
    return result;
  }

  /**
   * Matrix product matrixA * matrixB.
   * @return a new (rows of A) x (columns of B) matrix, or nullptr if either input
   *         is null or A's column count differs from B's row count
   */
  static std::unique_ptr<RowMatrix<T>> Multiply(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB) {
    // Validate pointers before dereferencing either operand.
    if (matrixA == nullptr || matrixB == nullptr) {
      return nullptr;
    }
    const int rows = matrixA->GetRowCount();
    const int inner = matrixA->GetColumnCount();
    const int cols = matrixB->GetColumnCount();
    if (inner != matrixB->GetRowCount()) {
      return nullptr;
    }
    auto result = std::make_unique<RowMatrix<T>>(rows, cols);
    for (int row = 0; row < rows; row++) {
      for (int col = 0; col < cols; col++) {
        // Seed with the first product so T does not need a zero/identity value.
        T sum = matrixA->GetElement(row, 0) * matrixB->GetElement(0, col);
        for (int k = 1; k < inner; k++) {
          sum += matrixA->GetElement(row, k) * matrixB->GetElement(k, col);
        }
        result->SetElement(row, col, sum);
      }
    }
    return result;
  }

  /**
   * General matrix multiply: matrixA * matrixB + matrixC.
   * @return a new matrix, or nullptr if any input is null or the dimensions are incompatible
   */
  static std::unique_ptr<RowMatrix<T>> GEMM(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB,
                                            const RowMatrix<T> *matrixC) {
    auto product = Multiply(matrixA, matrixB);
    // Handle failure here rather than handing a null product to Add.
    if (product == nullptr || matrixC == nullptr) {
      return nullptr;
    }
    return Add(product.get(), matrixC);
  }
};
}
|