#ifndef _H_TRANSWALKERS
#define _H_TRANSWALKERS
#include <security_cdsa_utilities/AuthorizationWalkers.h>
#include "flippers.h"
#include "server.h"
#include <security_cdsa_utilities/context.h>
using LowLevelMemoryUtilities::increment;
using LowLevelMemoryUtilities::difference;
bool flipClient();
class CheckingReconstituteWalker {
private:
void check(void *addr, size_t size)
{
if (addr < mBase || increment(addr, size) > mLimit)
CssmError::throwMe(CSSM_ERRCODE_INVALID_POINTER);
}
public:
CheckingReconstituteWalker(void *ptr, void *base, size_t size, bool flip);
template <class T>
void operator () (T &obj, size_t size = sizeof(T))
{
check(increment(&obj, -mOffset), size);
if (mFlip)
Flippers::flip(obj);
}
template <class T>
void operator () (T * &addr, size_t size = sizeof(T))
{
DEBUGWALK("checkreconst:ptr");
if (addr) {
void *p = addr;
blob(p, size);
addr = reinterpret_cast<T *>(p);
if (mFlip)
Flippers::flip(*addr);
}
}
template <class T>
void blob(T * &addr, size_t size)
{
DEBUGWALK("checkreconst:blob");
if (addr) {
if (mFlip) {
secdebug("flippers", "flipping %s@%p", Debug::typeName(addr).c_str(), addr);
Flippers::flip(addr);
}
check(addr, size);
addr = increment<T>(addr, mOffset);
}
}
static const bool needsRelinking = true;
static const bool needsSize = false;
private:
void *mBase; void *mLimit; off_t mOffset; bool mFlip; };
template <class T>
void relocate(T *obj, T *base, size_t size)
{
if (obj) {
if (base == NULL) CssmError::throwMe(CSSM_ERRCODE_INVALID_POINTER);
CheckingReconstituteWalker relocator(obj, base, size,
Server::process().byteFlipped());
walk(relocator, base);
}
}
void relocate(Context &context, void *base, Context::Attr *attrs, uint32 attrSize);
// Walker that collects, rather than performs, the byte flips a structure
// needs. Each visited object, pointer, or blob is wrapped in a heap-allocated
// flip command stored in mFlips. doFlips() and ~FlipWalker() are defined out
// of line and act on that collection; the destructor is presumably what
// deletes the commands -- confirm against the implementation.
// NOTE(review): the class is implicitly copyable; if the destructor deletes
// the commands, a copy would double-delete them. Consider making it noncopyable.
class FlipWalker {
private:
    // Polymorphic flip command; the virtual destructor allows deletion via Base *.
    struct Base {
        virtual ~Base() { }
        virtual void flip() const = 0;
    };

    // Flips a referenced object in place.
    template <class T>
    struct FlipRef : public Base {
        T &obj;
        FlipRef(T &s) : obj(s) { }
        void flip() const { Flippers::flip(obj); }
    };

    // Flips the pointee first, then the pointer value itself; the pointee
    // must be reached before the pointer's bytes are swapped.
    template <class T>
    struct FlipPtr : public Base {
        T * &obj;
        FlipPtr(T * &s) : obj(s) { }
        void flip() const { Flippers::flip(*obj); Flippers::flip(obj); }
    };

    // Flips only the pointer value; the pointed-to bytes are left as-is.
    template <class T>
    struct FlipBlob : public Base {
        T * &obj;
        FlipBlob(T * &s) : obj(s) { }
        void flip() const { Flippers::flip(obj); }
    };

    // Set element wrapping a flip command. NOTE(review): ordering compares
    // the commands' heap addresses, so iterating mFlips visits them in
    // allocation-address order, not in the order they were recorded.
    struct Flipper {
        Base *impl;
        Flipper(Base *p) : impl(p) { }
        bool operator < (const Flipper &other) const
        { return impl < other.impl; }
    };

    // Store a freshly allocated command in mFlips. If the insertion itself
    // throws (e.g. std::bad_alloc for the set node), the command is not yet
    // owned by the set, so delete it here instead of leaking it.
    void add(Base *flip)
    {
        try {
            mFlips.insert(flip);
        } catch (...) {
            delete flip;
            throw;
        }
    }

public:
    ~FlipWalker();
    void doFlips(bool active = true);

    // Record a flip of an object reached by reference.
    template <class T>
    void operator () (T &obj, size_t = sizeof(T))
    { add(new FlipRef<T>(obj)); }

    // Record a flip of a pointer and its pointee; returns the (unchanged) pointer.
    template <class T>
    T *operator () (T * &addr, size_t size = sizeof(T))
    { add(new FlipPtr<T>(addr)); return addr; }

    // Record a flip of a blob pointer value only.
    template <class T>
    void blob(T * &addr, size_t size)
    { add(new FlipBlob<T>(addr)); }

    // Walker traits; presumably consulted by the generic walk() templates.
    static const bool needsRelinking = true;
    static const bool needsSize = true;

private:
    set<Flipper> mFlips;    // recorded flip commands (raw pointers)
};
// Byte-flip a single value, but only when flipClient() reports that the
// current client needs flipping; otherwise the value is left untouched.
template <class T>
void flip(T &addr)
{
    if (!flipClient())
        return;
    secdebug("flippers", "raw flipping %s", Debug::typeName(addr).c_str());
    Flippers::flip(addr);
}
template <class T>
void flips(T *value, T ** &addr, T ** &base)
{
*addr = *base = value;
if (flipClient()) {
FlipWalker w; walk(w, value); w.doFlips(); Flippers::flip(*base); }
}
// Reinterpret blobData as a BlobType after basic sanity checks.
// The buffer must be non-null and at least sizeof(BlobType) bytes. The blob's
// own totalLength field must equal the buffer length exactly. On any failure
// CssmError::throwMe(error) is called (default CSSM_ERRCODE_INVALID_DATA).
// Returns a pointer into blobData's storage; no copy is made.
template <class BlobType>
const BlobType *makeBlob(const CssmData &blobData, CSSM_RETURN error = CSSM_ERRCODE_INVALID_DATA)
{
    const bool headerFits = blobData.data() && blobData.length() >= sizeof(BlobType);
    if (!headerFits)
        CssmError::throwMe(error);

    const BlobType *candidate = static_cast<const BlobType *>(blobData.data());
    // The self-declared length must agree with what was actually received.
    if (candidate->totalLength != blobData.length())
        CssmError::throwMe(error);
    return candidate;
}
// A CssmData bound to a pair of caller-owned out-parameters (data pointer and
// length), presumably the out-of-line reply arguments of a Mach/MIG routine.
// Assign the result to it; when it is destroyed, it writes its final data and
// length into those out-parameters and passes the data pointer to
// Server::releaseWhenDone (assumed to defer freeing until the reply is sent --
// confirm).
// NOTE(review): mLength is mach_msg_type_number_t, so a length() that does not
// fit is silently narrowed. An implicit copy would also register the same
// buffer twice with releaseWhenDone.
class OutputData : public CssmData {
public:
    OutputData(void **outP, mach_msg_type_number_t *outLength)
        : mData(*outP),
          mLength(*outLength)
    {
    }

    ~OutputData()
    {
        // Publish the final buffer to the caller's out-parameters...
        mData = data();
        mLength = length();
        // ...then hand it to the server for later release.
        Server::releaseWhenDone(mData);
    }

    // Take on the data pointer/length of `source`. The base-class assignment
    // copies the descriptor, not the bytes it points to.
    void operator = (const CssmData &source)
    {
        CssmData::operator = (source);
    }

private:
    void * &mData;                          // caller's data out-parameter
    mach_msg_type_number_t &mLength;        // caller's length out-parameter
};
// Choose a database from two candidates (either may be NULL); defined out of line.
Database *pickDb(Database *db1, Database *db2);

// The database a key belongs to, or NULL when there is no key.
static inline Database *dbOf(Key *key)
{
    if (key == NULL)
        return NULL;
    return &key->database();
}

// Convenience overloads: map any Key argument to its database, then defer
// to pickDb(Database *, Database *).
inline Database *pickDb(Key *k1, Key *k2)
{
    Database *first = dbOf(k1);
    Database *second = dbOf(k2);
    return pickDb(first, second);
}

inline Database *pickDb(Database *db1, Key *k2)
{
    Database *second = dbOf(k2);
    return pickDb(db1, second);
}

inline Database *pickDb(Key *k1, Database *db2)
{
    Database *first = dbOf(k1);
    return pickDb(first, db2);
}
#endif //_H_TRANSWALKERS