001    /**
002     * Copyright 2004-2012 The Kuali Foundation
003     *
004     * Licensed under the Educational Community License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.opensource.org/licenses/ecl2.php
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    package org.kuali.hr.time.util;
017    
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import javax.sql.DataSource;

import junit.framework.Assert;

import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;
import org.enhydra.jdbc.pool.StandardXAPoolDataSource;
import org.kuali.hr.time.service.base.TkServiceLocator;
import org.kuali.rice.core.api.lifecycle.BaseLifecycle;
import org.springframework.jdbc.core.ConnectionCallback;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.StatementCallback;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
045    
046    public class SQLDataLifeCycle  extends BaseLifecycle {
047            protected static final Logger LOG = Logger.getLogger(SQLDataLifeCycle.class);
048    
049            public static final String TEST_TABLE_NAME = "KR_UNITTEST_T";
050        Class callingTestClass = null;
051    
052        public SQLDataLifeCycle() {
053    
054        }
055    
056        public SQLDataLifeCycle(Class caller) {
057            this.callingTestClass = caller;
058        }
059    
060        public void start() throws Exception {
061            final StandardXAPoolDataSource dataSource = (StandardXAPoolDataSource) TkServiceLocator.CONTEXT.getBean("kpmeDataSource");
062            final PlatformTransactionManager transactionManager = (PlatformTransactionManager) TkServiceLocator.CONTEXT.getBean("transactionManager");
063            final String schemaName = dataSource.getUser().toUpperCase();
064            loadData(transactionManager, dataSource, schemaName);
065            super.start();
066        }
067    
068            public void loadData(final PlatformTransactionManager transactionManager, final DataSource dataSource, final String schemaName) {
069                    LOG.info("Clearing tables for schema " + schemaName);
070                    Assert.assertNotNull("DataSource could not be located.", dataSource);
071    
072                    if (schemaName == null || schemaName.equals("")) {
073                            Assert.fail("Empty schema name given");
074                    }
075                    new TransactionTemplate(transactionManager).execute(new TransactionCallback<Object>() {
076                public Object doInTransaction(final TransactionStatus status) {
077                    verifyTestEnvironment(dataSource);
078                    return new JdbcTemplate(dataSource).execute(new StatementCallback<Object>() {
079                            public Object doInStatement(Statement statement) throws SQLException {
080                            if (callingTestClass != null) {
081                                     List<String> sqlStatements = getTestDataSQLStatements("src/test/config/sql/" + callingTestClass.getSimpleName() + ".sql");
082                            
083                                    for(String sql : sqlStatements){
084                                    if (!sql.startsWith("#") && !sql.startsWith("//") && !StringUtils.isEmpty(sql.trim())) {
085                                        // ignore comment lines in our sql reader.
086                                                statement.addBatch(sql);
087                                    }
088                                    }
089                            }
090                                    statement.executeBatch();
091                                    return null;
092                            }
093                    });
094                }
095            });
096            }
097    
098            void verifyTestEnvironment(final DataSource dataSource) {
099                    Assert.assertTrue("No table named '" + TEST_TABLE_NAME + "' was found in the configured database.  " + "You are attempting to run tests against a non-test database!!!",
100                    isTestTableInSchema(dataSource));
101            }
102    
103            Boolean isTestTableInSchema(final DataSource dataSource) {
104                Assert.assertNotNull("DataSource could not be located.", dataSource);
105                return (Boolean) new JdbcTemplate(dataSource).execute(new ConnectionCallback() {
106                            public Object doInConnection(final Connection connection) throws SQLException {
107                                    final ResultSet resultSet = connection.getMetaData().getTables(null, null, TEST_TABLE_NAME, null);
108                                    return new Boolean(resultSet.next());
109                            }
110                    });
111            }
112    
113            List<String> getTestDataSQLStatements(String fname){
114                    List<String> testDataSqlStatements = new ArrayList<String>();
115                    try {
116                            BufferedReader in = new BufferedReader(new FileReader(fname));
117                            String str;
118                            while ((str = in.readLine()) != null) {
119                                    testDataSqlStatements.add(str);
120                            }
121                    } catch (FileNotFoundException e) {
122                            LOG.warn("No file found for " + fname);
123                    } catch (IOException e) {
124                            LOG.error("IO exception in test data loading");
125                    }
126                    return testDataSqlStatements;
127            }
128    
129    }