Fix CertficateDB locking scheme Currently we are locking every file going to be accessed by CertificateDB code even if it is not realy needed, because of a more general lock. This patch: - Replace the old FileLocker class with the pair Lock/Locker classes - Remove most of the locks in CertificateDB with only two locks one for main database locking and one lock for the file contain the current serial number. This is a Measurement Factory project === modified file 'src/ssl/certificate_db.cc' --- src/ssl/certificate_db.cc 2011-09-15 16:34:52 +0000 +++ src/ssl/certificate_db.cc 2011-09-22 09:25:48 +0000 @@ -1,69 +1,125 @@ /* * $Id$ */ #include "config.h" #include "ssl/certificate_db.h" +#if HAVE_ERRNO_H +#include +#endif #if HAVE_FSTREAM #include #endif #if HAVE_STDEXCEPT #include #endif #if HAVE_SYS_STAT_H #include #endif #if HAVE_SYS_FILE_H #include #endif #if HAVE_FCNTL_H #include #endif -Ssl::FileLocker::FileLocker(std::string const & filename) - : fd(-1) +#define HERE "(ssl_crtd) " << __FILE__ << ':' << __LINE__ << ": " + +Ssl::Lock::Lock(std::string const &aFilename) : + filename(aFilename), +#if _SQUID_MSWIN_ + hFile(INVALID_HANDLE_VALUE) +#else + fd(-1) +#endif +{ +} + +bool Ssl::Lock::locked() const { #if _SQUID_MSWIN_ + return hFile != INVALID_HANDLE_VALUE; +#else + return fd != -1; +#endif +} + +void Ssl::Lock::lock() +{ + +#if _SQUID_MSWIN_ hFile = CreateFile(TEXT(filename.c_str()), GENERIC_READ, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); - if (hFile != INVALID_HANDLE_VALUE) - LockFile(hFile, 0, 0, 1, 0); + if (hFile == INVALID_HANDLE_VALUE) #else fd = open(filename.c_str(), 0); - if (fd != -1) - flock(fd, LOCK_EX); + if (fd == -1) #endif + throw std::runtime_error("Failed to open file " + filename); + + +#if _SQUID_MSWIN_ + if (!LockFile(hFile, 0, 0, 1, 0)) +#else + if (flock(fd, LOCK_EX) != 0) +#endif + throw std::runtime_error("Failed to get a lock of " + filename); } -Ssl::FileLocker::~FileLocker() -{ +void Ssl::Lock::unlock() +{ #if _SQUID_MSWIN_ if (hFile != INVALID_HANDLE_VALUE) { UnlockFile(hFile, 0, 0, 1, 0); CloseHandle(hFile); + hFile = INVALID_HANDLE_VALUE; } #else if (fd != -1) { flock(fd, LOCK_UN); close(fd); + fd = -1; } #endif + else + throw std::runtime_error("Lock is already unlocked for " + filename); +} + +Ssl::Lock::~Lock() +{ + if (locked()) + unlock(); +} + +Ssl::Locker::Locker(Lock &aLock, const char *aFileName, int aLineNo): + weLocked(false), lock(aLock), fileName(aFileName), lineNo(aLineNo) +{ + if (!lock.locked()) { + lock.lock(); + weLocked = true; + } +} + +Ssl::Locker::~Locker() +{ + if (weLocked) + lock.unlock(); } Ssl::CertificateDb::Row::Row() : width(cnlNumber) { row = new char *[width + 1]; for (size_t i = 0; i < width + 1; i++) row[i] = NULL; } Ssl::CertificateDb::Row::~Row() { if (row) { for (size_t i = 0; i < width + 1; i++) { delete[](row[i]); } delete[](row); } } @@ -113,60 +169,60 @@ int Ssl::CertificateDb::index_name_cmp(const char **a, const char **b) { return(strcmp(a[Ssl::CertificateDb::cnlName], b[CertificateDb::cnlName])); } const std::string Ssl::CertificateDb::serial_file("serial"); const std::string Ssl::CertificateDb::db_file("index.txt"); const std::string Ssl::CertificateDb::cert_dir("certs"); const std::string Ssl::CertificateDb::size_file("size"); const size_t Ssl::CertificateDb::min_db_size(4096); Ssl::CertificateDb::CertificateDb(std::string const & aDb_path, size_t aMax_db_size, size_t aFs_block_size) : db_path(aDb_path), serial_full(aDb_path + "/" + serial_file), db_full(aDb_path + "/" + db_file), cert_full(aDb_path + "/" + cert_dir), size_full(aDb_path + "/" + size_file), db(NULL), max_db_size(aMax_db_size), fs_block_size(aFs_block_size), + dbLock(db_full), + dbSerialLock(serial_full), enabled_disk_store(true) { if (db_path.empty() && !max_db_size) enabled_disk_store = false; else if ((db_path.empty() && max_db_size) || (!db_path.empty() && !max_db_size)) throw std::runtime_error("ssl_crtd is missing the required parameter. There should be -s and -M parameters together."); - else - load(); } bool Ssl::CertificateDb::find(std::string const & host_name, Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey) { - FileLocker db_locker(db_full); + const Locker locker(dbLock, Here); load(); return pure_find(host_name, cert, pkey); } bool Ssl::CertificateDb::addCertAndPrivateKey(Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey) { - FileLocker db_locker(db_full); + const Locker locker(dbLock, Here); load(); if (!db || !cert || !pkey || min_db_size > max_db_size) return false; Row row; ASN1_INTEGER * ai = X509_get_serialNumber(cert.get()); std::string serial_string; Ssl::BIGNUM_Pointer serial(ASN1_INTEGER_to_BN(ai, NULL)); { TidyPointer hex_bn(BN_bn2hex(serial.get())); serial_string = std::string(hex_bn.get()); } row.setValue(cnlSerial, serial_string.c_str()); char ** rrow = TXT_DB_get_by_index(db.get(), cnlSerial, row.getRow()); if (rrow != NULL) return false; { TidyPointer subject(X509_NAME_oneline(X509_get_subject_name(cert.get()), NULL, 0)); if (pure_find(subject.get(), cert, pkey)) return true; @@ -178,52 +234,51 @@ } while (max_db_size < size()) { deleteOldestCertificate(); } row.setValue(cnlType, "V"); ASN1_UTCTIME * tm = X509_get_notAfter(cert.get()); row.setValue(cnlExp_date, std::string(reinterpret_cast(tm->data), tm->length).c_str()); row.setValue(cnlFile, "unknown"); { TidyPointer subject(X509_NAME_oneline(X509_get_subject_name(cert.get()), NULL, 0)); row.setValue(cnlName, subject.get()); } if (!TXT_DB_insert(db.get(), row.getRow())) return false; row.reset(); std::string filename(cert_full + "/" + serial_string + ".pem"); - FileLocker cert_locker(filename); if (!writeCertAndPrivateKeyToFile(cert, pkey, filename.c_str())) return false; addSize(filename); save(); return true; } BIGNUM * Ssl::CertificateDb::getCurrentSerialNumber() { - FileLocker serial_locker(serial_full); + const Locker locker(dbSerialLock, Here); // load serial number from file. Ssl::BIO_Pointer file(BIO_new(BIO_s_file())); if (!file) return NULL; if (BIO_rw_filename(file.get(), const_cast(serial_full.c_str())) <= 0) return NULL; Ssl::ASN1_INT_Pointer serial_ai(ASN1_INTEGER_new()); if (!serial_ai) return NULL; char buffer[1024]; if (!a2i_ASN1_INTEGER(file.get(), serial_ai.get(), buffer, sizeof(buffer))) return NULL; Ssl::BIGNUM_Pointer serial(ASN1_INTEGER_to_BN(serial_ai.get(), NULL)); if (!serial) return NULL; @@ -280,94 +335,91 @@ throw std::runtime_error("SSL error"); if (BIO_write_filename(file.get(), const_cast(serial_full.c_str())) <= 0) throw std::runtime_error("Cannot open " + cert_full + " to open"); i2a_ASN1_INTEGER(file.get(), i.get()); std::ofstream size(size_full.c_str()); if (size) size << 0; else throw std::runtime_error("Cannot open " + size_full + " to open"); std::ofstream db(db_full.c_str()); if (!db) throw std::runtime_error("Cannot open " + db_full + " to open"); } void Ssl::CertificateDb::check(std::string const & db_path, size_t max_db_size) { CertificateDb db(db_path, max_db_size, 0); + db.load(); } std::string Ssl::CertificateDb::getSNString() const { - FileLocker serial_locker(serial_full); + const Locker locker(dbSerialLock, Here); std::ifstream file(serial_full.c_str()); if (!file) return ""; std::string serial; file >> serial; return serial; } bool Ssl::CertificateDb::pure_find(std::string const & host_name, Ssl::X509_Pointer & cert, Ssl::EVP_PKEY_Pointer & pkey) { if (!db) return false; Row row; row.setValue(cnlName, host_name.c_str()); char **rrow = TXT_DB_get_by_index(db.get(), cnlName, row.getRow()); if (rrow == NULL) return false; if (!sslDateIsInTheFuture(rrow[cnlExp_date])) { deleteByHostname(rrow[cnlName]); return false; } // read cert and pkey from file. std::string filename(cert_full + "/" + rrow[cnlSerial] + ".pem"); - FileLocker cert_locker(filename); readCertAndPrivateKeyFromFiles(cert, pkey, filename.c_str(), NULL); if (!cert || !pkey) return false; return true; } size_t Ssl::CertificateDb::size() const { - FileLocker size_locker(size_full); return readSize(); } void Ssl::CertificateDb::addSize(std::string const & filename) { - FileLocker size_locker(size_full); writeSize(readSize() + getFileSize(filename)); } void Ssl::CertificateDb::subSize(std::string const & filename) { - FileLocker size_locker(size_full); writeSize(readSize() - getFileSize(filename)); } size_t Ssl::CertificateDb::readSize() const { size_t db_size; std::ifstream size_file(size_full.c_str()); if (!size_file && enabled_disk_store) throw std::runtime_error("cannot read \"" + size_full + "\" file"); size_file >> db_size; return db_size; } void Ssl::CertificateDb::writeSize(size_t db_size) { std::ofstream size_file(size_full.c_str()); if (!size_file && enabled_disk_store) throw std::runtime_error("cannot write \"" + size_full + "\" file"); size_file << db_size; } @@ -415,41 +467,40 @@ void Ssl::CertificateDb::save() { if (!db) throw std::runtime_error("The certificates database is not loaded");; // To save the db to file, create a new BIO with BIO file methods. Ssl::BIO_Pointer out(BIO_new(BIO_s_file())); if (!out || !BIO_write_filename(out.get(), const_cast(db_full.c_str()))) throw std::runtime_error("Failed to initialize " + db_full + " file for writing");; if (TXT_DB_write(out.get(), db.get()) < 0) throw std::runtime_error("Failed to write " + db_full + " file"); } // Normally defined in defines.h file #define countof(arr) (sizeof(arr)/sizeof(*arr)) void Ssl::CertificateDb::deleteRow(const char **row, int rowIndex) { const std::string filename(cert_full + "/" + row[cnlSerial] + ".pem"); - const FileLocker cert_locker(filename); #if OPENSSL_VERSION_NUMBER >= 0x1000004fL sk_OPENSSL_PSTRING_delete(db.get()->data, rowIndex); #else sk_delete(db.get()->data, rowIndex); #endif const Columns db_indexes[]={cnlSerial, cnlName}; for (unsigned int i = 0; i < countof(db_indexes); i++) { #if OPENSSL_VERSION_NUMBER >= 0x1000004fL if (LHASH_OF(OPENSSL_STRING) *fieldIndex = db.get()->index[db_indexes[i]]) lh_OPENSSL_STRING_delete(fieldIndex, (char **)row); #else if (LHASH *fieldIndex = db.get()->index[db_indexes[i]]) lh_delete(fieldIndex, row); #endif } subSize(filename); int ret = remove(filename.c_str()); if (ret < 0) === modified file 'src/ssl/certificate_db.h' --- src/ssl/certificate_db.h 2011-09-15 16:34:52 +0000 +++ src/ssl/certificate_db.h 2011-09-22 09:27:14 +0000 @@ -1,54 +1,74 @@ /* * $Id$ */ #ifndef SQUID_SSL_CERTIFICATE_DB_H #define SQUID_SSL_CERTIFICATE_DB_H #include "ssl/gadgets.h" #include "ssl/support.h" #if HAVE_STRING #include #endif #if HAVE_OPENSSL_OPENSSLV_H #include #endif namespace Ssl { -/// Cross platform file locker. -class FileLocker -{ +/// maintains an exclusive blocking file-based lock +class Lock { public: - /// Lock file - FileLocker(std::string const & aFilename); - /// Unlock file - ~FileLocker(); + explicit Lock(std::string const &filename); ///< creates an unlocked lock + ~Lock(); ///< releases the lock if it is locked + void lock(); ///< locks the lock, may block + void unlock(); ///< unlocks locked lock or throws + bool locked() const; ///< whether our lock is locked + const char *name() const { return filename.c_str(); } private: + std::string filename; #if _SQUID_MSWIN_ HANDLE hFile; ///< Windows file handle. #else int fd; ///< Linux file descriptor. #endif }; +/// an exception-safe way to obtain and release a lock +class Locker +{ +public: + /// locks the lock if the lock was unlocked + Locker(Lock &lock, const char *aFileName, int lineNo); + /// unlocks the lock if it was locked by us + ~Locker(); +private: + bool weLocked; ///< whether we locked the lock + Lock &lock; ///< the lock we are operating on + const std::string fileName; ///< where the lock was needed + const int lineNo; ///< where the lock was needed +}; + +/// convenience macro to pass source code location to Locker and others +#define Here __FILE__, __LINE__ + /** * Database class for storing SSL certificates and their private keys. * A database consist by: * - A disk file to store current serial number * - A disk file to store the current database size * - A disk file which is a normal TXT_DB openSSL database * - A directory under which the certificates and their private keys stored. * The database before used must initialized with CertificateDb::create static method. */ class CertificateDb { public: /// Names of db columns. enum Columns { cnlType = 0, cnlExp_date, cnlRev_date, cnlSerial, cnlFile, cnlName, @@ -133,26 +153,28 @@ static IMPLEMENT_LHASH_HASH_FN(index_name_hash,const char **) static IMPLEMENT_LHASH_COMP_FN(index_name_cmp,const char **) #endif static const std::string serial_file; ///< Base name of the file to store serial number. static const std::string db_file; ///< Base name of the database index file. static const std::string cert_dir; ///< Base name of the directory to store the certs. static const std::string size_file; ///< Base name of the file to store db size. /// Min size of disk db. If real size < min_db_size the db will be disabled. static const size_t min_db_size; const std::string db_path; ///< The database directory. const std::string serial_full; ///< Full path of the file to store serial number. const std::string db_full; ///< Full path of the database index file. const std::string cert_full; ///< Full path of the directory to store the certs. const std::string size_full; ///< Full path of the file to store the db size. TXT_DB_Pointer db; ///< Database with certificates info. const size_t max_db_size; ///< Max size of db. const size_t fs_block_size; ///< File system block size. + mutable Lock dbLock; ///< protects the database file + mutable Lock dbSerialLock; ///< protects the serial number file bool enabled_disk_store; ///< The storage on the disk is enabled. }; } // namespace Ssl #endif // SQUID_SSL_CERTIFICATE_DB_H