This is the fifth in a series of posts on how to build a LINQ IQueryable provider. If you have not read the previous posts, please take a look before proceeding, or, if you are daring, dig right in.

Complete list of posts in the Building an IQueryable Provider series 

Over the past four parts of this series I have constructed a working LINQ IQueryable provider that targets ADO and SQL and has so far been able to translate both the Queryable.Where and Queryable.Select standard query operators. Yet, as big an accomplishment as that has been, there are still a few gaping holes, and I’m not talking about other missing operators like OrderBy and Join. I’m talking about huge conceptual gaffes that will bite anyone who strays from my oh-so-ideally crafted demo queries.

Fixing the Gaping Holes

Certainly, I can write a simple where/select pair and it works as advertised.  My select expression can be arbitrarily complex and it still chugs along.

var query = db.Customers.Where(c => c.City == city)
                        .Select(c => new {
                            Name = c.ContactName,
                            Location = c.City });

However, if I just rearrange the order of Where and Select, it all falls apart.

var query = db.Customers.Select(c => new {
                            Name = c.ContactName,
                            Location = c.City })
                        .Where(x => x.Location == city);

This handsome little query generates SQL that’s not exactly right.

SELECT * FROM (SELECT ContactName, City FROM (SELECT * FROM Customers) AS T) AS T WHERE (Location = 'London')

It also throws an exception when executed: “Invalid column name 'Location'.” It seems my oversimplified practice of treating member accesses as column references has backfired. I naively assumed that every member access in the sub-trees would match the name of a column being generated. That is obviously not true. Location is a member of the anonymous type produced by the Select, while the column underneath it is still City. So either I need to find a way to rename the columns to match, or I need to figure out some other way to deal with member accesses.

I suppose either is possible. Yet, if I consider a slightly more complicated example, renaming the columns is not sufficient. If the select expression produces a hierarchy of objects, then references to members can become ‘multi-dot’ operations, such as x.Location.City below.

var query = db.Customers.Select(c => new {
                            Name = c.ContactName,
                            Location = new {
                                City = c.City,
                                Country = c.Country
                                }
                            })
                        .Where(x => x.Location.City == city);

Now, how am I going to translate this? The existing code does not even have a concept of what this intermediate ‘Location’ object could be. Luckily, I already know what I need to do, though it’s going to require a big change. I need to step away from the notion that my provider is just translating query expressions into text. It’s translating query expressions into SQL. Text is only one possible manifestation of SQL, and it’s not a very good one for program logic to operate on. Of course, I’m going to need text eventually, but if I could first represent SQL as an abstraction, I could handle much more complicated translations.
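It helps to see how the multi-dot predicate actually reaches the provider. x.Location.City is a MemberExpression for City. Its inner expression is another MemberExpression for Location, which finally bottoms out at the lambda parameter x. The sketch below walks that chain. It makes a few assumptions:

* CustomerView and LocationInfo are hypothetical named stand-ins for the anonymous types in the query.
* MemberChain is an illustrative helper, not part of the provider.
* A constant replaces the captured city variable.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Hypothetical named stand-ins for the anonymous types produced by the Select.
public class LocationInfo {
    public string City;
    public string Country;
}

public class CustomerView {
    public string Name;
    public LocationInfo Location;
}

// Hypothetical helper for inspecting member-access chains.
public static class MemberChain {
    // The predicate from the query above, with a constant in place of 'city'.
    public static readonly Expression<Func<CustomerView, bool>> Predicate =
        x => x.Location.City == "London";

    // Walks a chain of member accesses inward and returns the member names
    // in source order, e.g. x.Location.City yields ["Location", "City"].
    public static List<string> Path(Expression e) {
        var names = new List<string>();
        MemberExpression m = e as MemberExpression;
        while (m != null) {
            names.Insert(0, m.Member.Name);
            m = m.Expression as MemberExpression;
        }
        return names;
    }

    // The left operand of the == comparison, i.e. the x.Location.City node.
    public static Expression LeftOperand() {
        return ((BinaryExpression)Predicate.Body).Left;
    }
}
```

Calling MemberChain.Path(MemberChain.LeftOperand()) yields the two names Location and City. Renaming columns cannot rescue this, because Location does not correspond to any single column at all. It stands for a whole object built from two of them.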

Of course, the best data structure to operate on is a semantic SQL tree. Ideally, I would have an entirely separate tree definition for SQL that I could translate LINQ query expressions into, but that would be a lot of work. Luckily, the definition of this ideal SQL tree would overlap a lot with LINQ trees, so I’m going to cheat and simply teach LINQ expression trees about SQL. To do this, I’ll add some new expression node types. It won’t matter that no other LINQ API understands them; I’ll just keep them to myself.

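// Namespaces used by the code in this post.
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

internal enum DbExpressionType {
    Table = 1000, // make sure these don't overlap with ExpressionType
    Column,
    Select,
    Projection
}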

internal class TableExpression : Expression {
    string alias;
    string name;
    internal TableExpression(Type type, string alias, string name)
        : base((ExpressionType)DbExpressionType.Table, type) {
        this.alias = alias;
        this.name = name;
    }
    internal string Alias {
        get { return this.alias; }
    }
    internal string Name {
        get { return this.name; }
    }
}

internal class ColumnExpression : Expression {
    string alias;
    string name;
    int ordinal;
    internal ColumnExpression(Type type, string alias, string name, int ordinal)
        : base((ExpressionType)DbExpressionType.Column, type) {
        this.alias = alias;
        this.name = name;
        this.ordinal = ordinal;
    }
    internal string Alias {
        get { return this.alias; }
    }
    internal string Name {
        get { return this.name; }
    }
    internal int Ordinal {
        get { return this.ordinal; }
    }
}

internal class ColumnDeclaration {
    string name;
    Expression expression;
    internal ColumnDeclaration(string name, Expression expression) {
        this.name = name;
        this.expression = expression;
    }
    internal string Name {
        get { return this.name; }
    }
    internal Expression Expression {
        get { return this.expression; }
    }
}

internal class SelectExpression : Expression {
    string alias;
    ReadOnlyCollection<ColumnDeclaration> columns;
    Expression from;
    Expression where;
    internal SelectExpression(Type type, string alias, IEnumerable<ColumnDeclaration> columns, Expression from, Expression where)
        : base((ExpressionType)DbExpressionType.Select, type) {
        this.alias = alias;
        this.columns = columns as ReadOnlyCollection<ColumnDeclaration>;
        if (this.columns == null) {
            this.columns = new List<ColumnDeclaration>(columns).AsReadOnly();
        }
        this.from = from;
        this.where = where;
    }
    internal string Alias {
        get { return this.alias; }
    }
    internal ReadOnlyCollection<ColumnDeclaration> Columns {
        get { return this.columns; }
    }
    internal Expression From {
        get { return this.from; }
    }
    internal Expression Where {
        get { return this.where; }
    }
}

internal class ProjectionExpression : Expression {
    SelectExpression source;
    Expression projector;
    internal ProjectionExpression(SelectExpression source, Expression projector)
        : base((ExpressionType)DbExpressionType.Projection, projector.Type) {
        this.source = source;
        this.projector = projector;
    }
    internal SelectExpression Source {
        get { return this.source; }
    }
    internal Expression Projector {
        get { return this.projector; }
    }
}

The only bits of SQL I really need to add to LINQ expression trees are the concepts of a SQL select query that produces one or more columns, a reference to a column, a reference to a table, and a projection that reassembles objects out of column references.

I went ahead and defined my own DbExpressionType enum that ‘extends’ the base ExpressionType enum by picking a starting value large enough not to collide with the existing definitions. If there were a way to derive from an enum I would have done that, but this will work as long as I am diligent.
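The trick relies on casting integers that ExpressionType does not define. It is therefore worth checking two things: that nothing in ExpressionType actually reaches the starting value, and that the cast round-trips. The sketch below does both at runtime. It redeclares DbExpressionType so it compiles on its own, and EnumExtensionCheck is a hypothetical helper.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Mirror of the DbExpressionType enum above, redeclared so this sketch stands alone.
public enum DbExpressionType {
    Table = 1000, // assumed to be above every ExpressionType value
    Column,
    Select,
    Projection
}

// Hypothetical helper for sanity-checking the enum 'extension' trick.
public static class EnumExtensionCheck {
    // True if every ExpressionType member is below the DbExpressionType range.
    public static bool DoesNotOverlap() {
        int highest = Enum.GetValues(typeof(ExpressionType)).Cast<int>().Max();
        return highest < (int)DbExpressionType.Table;
    }

    // Casting to ExpressionType and back preserves the numeric value,
    // even though ExpressionType has no member with that value.
    public static DbExpressionType RoundTrip(DbExpressionType value) {
        ExpressionType asNodeType = (ExpressionType)value;
        return (DbExpressionType)asNodeType;
    }

    // Reports whether ExpressionType itself defines the value (false for ours).
    public static bool IsKnownToLinq(DbExpressionType value) {
        return Enum.IsDefined(typeof(ExpressionType), (int)value);
    }
}
```

ExpressionType’s highest value is far below 1000, so DoesNotOverlap should return true. Enum.IsDefined, however, reports the extended values as unknown. That is a reminder that only the provider’s own visitors can make sense of these nodes.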

Each of the new expression nodes follows the pattern set by the LINQ expressions (they are immutable, etc.), except that they model SQL concepts instead of CLR concepts. Notice how the SelectExpression contains a collection of columns along with both a from and a where expression. That is because this expression node is meant to match what a legal SQL select statement would contain.

The ProjectionExpression describes how to construct a result object out of the columns of a select expression. If you think about it, this is almost exactly the job that the projection expression held in Part IV, the one that was used to build the delegate for the ProjectionReader. This time, however, it’s possible to reason about projection in terms of the SQL query and not just as a function that assembles objects out of DataReaders.
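Here is a minimal sketch of what a projector boils down to once it is compiled. It is a function from a row of column values to a result object, with each member bound to a column by ordinal. So that it runs on its own, it uses an object[] in place of a DataReader. NameAndCity and ProjectorSketch are hypothetical names.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical result type that the projector reassembles.
public class NameAndCity {
    public string Name;
    public string City;
}

// Hypothetical helper; an object[] stands in for a DataReader row.
public static class ProjectorSketch {
    // Builds and compiles:
    //   row => new NameAndCity { Name = (string)row[0], City = (string)row[1] }
    // The ordinals are assumed to match the column order of the SELECT.
    public static Func<object[], NameAndCity> Build() {
        ParameterExpression row = Expression.Parameter(typeof(object[]), "row");

        // Reads the column at 'ordinal' and casts it to the member's type.
        Expression Column(int ordinal, Type type) =>
            Expression.Convert(Expression.ArrayIndex(row, Expression.Constant(ordinal)), type);

        MemberInitExpression body = Expression.MemberInit(
            Expression.New(typeof(NameAndCity)),
            Expression.Bind(typeof(NameAndCity).GetField("Name"), Column(0, typeof(string))),
            Expression.Bind(typeof(NameAndCity).GetField("City"), Column(1, typeof(string))));

        return Expression.Lambda<Func<object[], NameAndCity>>(body, row).Compile();
    }
}
```

In this shape, the ProjectionExpression’s job is to keep the MemberInit part as a tree until the very end. The binder can then still inspect which member maps to which column before anything is compiled.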

Of course, now that I’ve got new nodes, I need a new visitor. The DbExpressionVisitor extends the ExpressionVisitor, adding the base visit pattern for the new nodes.

internal class DbExpressionVisitor : ExpressionVisitor {
    protected override Expression Visit(Expression exp) {
        if (exp == null) {
            return null;
        }
        switch ((DbExpressionType)exp.NodeType) {
            case DbExpressionType.Table:
                return this.VisitTable((TableExpression)exp);
            case DbExpressionType.Column:
                return this.VisitColumn((ColumnExpression)exp);
            case DbExpressionType.Select:
                return this.VisitSelect((SelectExpression)exp);
            case DbExpressionType.Projection:
                return this.VisitProjection((ProjectionExpression)exp);
            default:
                return base.Visit(exp);
        }
    }
    protected virtual Expression VisitTable(TableExpression table) {
        return table;
    }
    protected virtual Expression VisitColumn(ColumnExpression column) {
        return column;
    }
    protected virtual Expression VisitSelect(SelectExpression select) {
        Expression from = this.VisitSource(select.From);
        Expression where = this.Visit(select.Where);
        ReadOnlyCollection<ColumnDeclaration> columns = this.VisitColumnDeclarations(select.Columns);
        if (from != select.From || where != select.Where || columns != select.Columns) {
            return new SelectExpression(select.Type, select.Alias, columns, from, where);
        }
        return select;
    }
    protected virtual Expression VisitSource(Expression source) {
        return this.Visit(source);
    }
    protected virtual Expression VisitProjection(ProjectionExpression proj) {
        SelectExpression source = (SelectExpression)this.Visit(proj.Source);
        Expression projector = this.Visit(proj.Projector);
        if (source != proj.Source || projector != proj.Projector) {
            return new ProjectionExpression(source, projector);
        }
        return proj;
    }
    protected ReadOnlyCollection<ColumnDeclaration> VisitColumnDeclarations(ReadOnlyCollection<ColumnDeclaration> columns) {
        List<ColumnDeclaration> alternate = null;
        for (int i = 0, n = columns.Count; i < n; i++) {
            ColumnDeclaration column = columns[i];
            Expression e = this.Visit(column.Expression);
            if (alternate == null && e != column.Expression) {
                alternate = columns.Take(i).ToList();
            }
            if (alternate != null) {
                alternate.Add(new ColumnDeclaration(column.Name, e));
            }
        }
        if (alternate != null) {
            return alternate.AsReadOnly();
        }
        return columns;
    }
}

That’s better.  Now I feel like I’m really headed somewhere!
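VisitSelect, VisitProjection and VisitColumnDeclarations all follow the rule that keeps rewriting immutable trees cheap. A visitor returns the original node unless a child actually changed, and only then allocates a replacement. The sketch below isolates that rule for a read-only list, with hypothetical names CopyOnChange and VisitAll. Like VisitColumnDeclarations, it copies the untouched prefix only when it meets the first changed element.

```csharp
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;

// Hypothetical helper illustrating the copy-on-change visiting rule.
public static class CopyOnChange {
    // Applies 'visit' to each item. Returns the original collection if every
    // result is reference-equal to its input; otherwise returns a new read-only copy.
    public static ReadOnlyCollection<T> VisitAll<T>(ReadOnlyCollection<T> items, Func<T, T> visit)
        where T : class {
        List<T> alternate = null;
        for (int i = 0, n = items.Count; i < n; i++) {
            T visited = visit(items[i]);
            if (alternate == null && !ReferenceEquals(visited, items[i])) {
                // First change: copy the untouched prefix before continuing.
                alternate = items.Take(i).ToList();
            }
            if (alternate != null) {
                alternate.Add(visited);
            }
        }
        return alternate != null ? alternate.AsReadOnly() : items;
    }
}
```

Visiting a list of "ContactName" and "City" with an identity function hands back the very same collection object. Mapping "City" to "Location" returns a new collection and leaves the original untouched. Reference equality is what lets VisitSelect decide, with a few comparisons, whether it must build a new SelectExpression at all.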

The next step is to take a stick of dynamite and blow up the QueryTranslator. No more monolithic expression-tree-to-string translator. What I need are individual pieces that handle separate tasks: one to bind the expression tree by figuring out what methods like Queryable.Select mean, and another to convert the resulting tree into SQL text. Hopefully, by concocting this LINQ/SQL hybrid tree I’ll be able to sort out the member access mess.
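One way the member access mess could be untangled is to substitute the projector for the lambda parameter it stands for. Each member access is then folded against the construction that produced it. The sketch below does this under simplifying assumptions:

* It uses the System.Linq.Expressions.ExpressionVisitor that ships with .NET 4 and later, not this series’ own visitor.
* It handles only anonymous-type projections (a NewExpression with Members).
* Customer, MemberResolver and Resolver are hypothetical names.

The parameter/projector pair plays the same role as the map the QueryBinder keeps from lambda parameters to projectors.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical row type standing in for the Customers table.
public class Customer {
    public string ContactName;
    public string City;
    public string Country;
}

// Hypothetical visitor: replaces a lambda parameter with the projector it
// stands for, then collapses member accesses on anonymous-type constructions.
public class MemberResolver : ExpressionVisitor {
    private readonly ParameterExpression parameter;
    private readonly Expression projector;

    public MemberResolver(ParameterExpression parameter, Expression projector) {
        this.parameter = parameter;
        this.projector = projector;
    }

    protected override Expression VisitParameter(ParameterExpression node) {
        return node == this.parameter ? this.projector : node;
    }

    protected override Expression VisitMember(MemberExpression node) {
        // Resolve the inner expression first so x.Location.City collapses bottom-up.
        Expression source = this.Visit(node.Expression);
        NewExpression nex = source as NewExpression;
        if (nex != null && nex.Members != null) {
            for (int i = 0; i < nex.Members.Count; i++) {
                string name = nex.Members[i].Name;
                // Older runtimes may record the getter (get_City) instead of the property.
                if (name == node.Member.Name || name == "get_" + node.Member.Name) {
                    return nex.Arguments[i];
                }
            }
        }
        return node.Update(source);
    }
}

public static class Resolver {
    // T is inferred from the selector, so the predicate can be written
    // against the anonymous type, just like the query in this post.
    public static Expression Resolve<T>(
        Expression<Func<Customer, T>> selector,
        Expression<Func<T, bool>> predicate) {
        var resolver = new MemberResolver(predicate.Parameters[0], selector.Body);
        return resolver.Visit(predicate.Body);
    }
}
```

Take the nested projection from earlier, c => new { Name = c.ContactName, Location = new { City = c.City, Country = c.Country } }. Resolving x => x.Location.City == "London" against it leaves a comparison whose left side is c.City, a plain reference to a member of the underlying row. In the provider, that member would in turn be a column.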

Here’s the code for the new QueryBinder class.

internal class QueryBinder : ExpressionVisitor {
    ColumnProjector columnProjector;
    Dictionary<ParameterExpression, Expression> map;
    int aliasCount;

    internal QueryBinder() {
        this.columnProjector = new ColumnProjector(this.CanBeColumn);
    }

    private bool CanBeColumn(Expression expression) {
        return expression.NodeType == (ExpressionType)DbExpressionType.Column;
    }

    internal Expression Bind(Expression expression) {
        this.map = new Dictionary<ParameterExpression, Expression>();
        return this.Visit(expression);
    }

    private static Expression StripQuotes(Expression e) {
        while (e.NodeType == ExpressionType.Quote) {
            e = ((UnaryExpression)e).Operand;
        }
        return e;
    }

    private string GetNextAlias() {
        return "t" + (aliasCount++);
    }

    private ProjectedColumns ProjectColumns(Expression expression, string newAlias, string existingAlias) {
        return this.columnProjector.ProjectColumns(expression, newAlias, existingAlias);
    }

    protected override Expression VisitMethodCall(MethodCallExpression m) {
        if (m.Method.DeclaringType == typeof(Queryable) ||
            m.Method.DeclaringType == typeof(Enumerable)) {
            switch (m.Method.Name) {
                case "Where":
                    return this.BindWhere(m.Type, m.Arguments[0], (LambdaExpression)StripQuotes(m.Arguments[1]));
                case "Select":
                    return this.BindSelect(m.Type, m.Arguments[0], (LambdaExpression)StripQuotes(m.Arguments[1]));
            }
            throw new NotSupportedException(string.Format("The method '{0}' is not supported", m.Method.Name));
        }
        return base.VisitMethodCall(m);
    }

    private Expression BindWhere(Type resultType, Expression source, LambdaExpression predicate) {
        ProjectionExpression projection = (ProjectionExpression)this.Visit(source);
        this.map[predicate.Parameters[0]] = projection.Projector;
        Expression where = this.Visit(predicate.Body);
        string alias = this.GetNextAlias();
        ProjectedColumns pc = this.ProjectColumns(projection.Projector, alias, GetExistingAlias(projection.Source));
        return new ProjectionExpression(
            new SelectExpression(resultType, alias, pc.Columns, projection.Source, where),
            pc.Projector
            );
    }

    private Expression BindSelect(Type resultType, Expression source, LambdaExpression selector) {
        ProjectionExpression projection = (ProjectionExpression)this.Visit(source);
        this.map[selector.Parameters[0]] = projection.Projector;
        Expression expression = this.Visit(selector.Body);
        string alias = this.GetNextAlias();
        ProjectedColumns pc = this.ProjectColumns(expression, alias, GetExistingAlias(projection.Source));
        return new ProjectionExpression(
            new SelectExpression(resultType, alias, pc.Columns, projection.Source, null),
            pc.Projector
            );
    }

    private static string GetExistingAlias(Expression source) {
        switch ((DbExpressionType)source.NodeType) {
            case DbExpressionType.Select:
                return ((SelectExpression)source).Alias;
            case DbExpressionType.Table:
                return ((TableExpression)source).Alias;
            default:
                throw new InvalidOperationException(string.Format("Invalid source node type '{0}'", source.NodeType));
        }
    }

    private bool IsTable(object value) {
        IQueryable q = value as IQueryable;
        return q != null && q.Expression.NodeType == ExpressionType.Constant;
    }

    private string GetTableName(object table) {
        IQueryable tableQuery = (IQueryable)table;
        Type rowType = tableQuery.ElementType;
        return rowType.Name;
    }

    private string GetColumnName(MemberInfo member) {
        return member.Name;
    }

    private Type GetColumnType(MemberInfo member) {
        FieldInfo fi = member as FieldInfo;
        if (fi != null) {
            return fi.FieldType;
        }
        PropertyInfo pi = (PropertyInfo)member;
        return pi.PropertyType;
    }

    private IEnumerable<MemberInfo> GetMappedMembers(Type rowType) {
        return rowType.GetFields().Cast<MemberInfo>();
    }

    private ProjectionExpression GetTableProjection(object value) {
        IQueryable table = (IQueryable)value;
        string tableAlias = this.GetNextAlias();
        string selectAlias = this.GetNextAlias();
        List<MemberBinding> bindings = new List<MemberBinding>();
        List<ColumnDeclaration> columns = new List<ColumnDeclaration>();
        foreach (MemberInfo mi in this.GetMappedMembers(table.ElementType)) {
            string columnName = this.GetColumnName(mi);
            Type columnType = this.GetColumnType(mi);
            int ordinal = columns.Count;