Untitled
unknown
plain_text
a month ago
17 kB
4
Indexable
import json
import logging
import os
import random
import re
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta

import pandas as pd
import timeout_decorator
from airflow import DAG
from airflow.decorators import task
from airflow.exceptions import AirflowTaskTimeout
from airflow.models import Variable
from airflow.models.baseoperator import chain
from airflow.operators.python import PythonOperator
from airflow.utils.task_group import TaskGroup

sys.path.extend(
    [
        "/app/scraper/src/fetchers/reference_data/tradingview",
        "/app/scraper/src/database/reference_data",
    ]
)

# Both modules below are resolved through the directories appended to sys.path
# above, so they must be imported after the extend() call.
# The wildcard import presumably supplies the TradingView scraper helpers used by
# the tasks (get_tradingview_*, view_data) — confirm against sector_and_industries.
from sector_and_industries import *  # noqa: E402,F403  # pylint: disable=wildcard-import
from cassandra_client import CassandraClient  # noqa: E402  # pylint: disable=import-error

default_args = {
    "owner": "airflow",
    "depends_on_past": False,
    "start_date": datetime(2024, 2, 17),
    "retries": 3,
    "retry_delay": timedelta(minutes=5),
    "execution_timeout": timedelta(minutes=60),
}

# ==============================
# 🔹 Step 0: load list countries
# ==============================

# JSON file listing countries grouped by region (layout described in get_all_countries).
COUNTRIES_JSON_PATH = "/app/scraper/src/database/reference_data/scraping_raw_json/tradingview/countries_with_flags.json"

# Any character that is not a word character ([A-Za-z0-9_] plus Unicode letters).
_NON_WORD_CHARS = re.compile(r"\W")


def clean_country_name(country: str) -> str:
    """Convert a country name into a fragment that is safe in Airflow task and group ids.

    The name is lower-cased, spaces and hyphens become underscores and
    parentheses are dropped, e.g. "Congo (Kinshasa)" -> "congo_kinshasa".
    Any other non-word character left over (apostrophe, comma, period,
    ampersand, ...) is replaced with an underscore as well: Airflow only
    accepts word characters, dashes and (for task ids, not TaskGroup ids)
    dots, so such a name would otherwise make the DAG fail to import.
    Names that already produced valid ids are mapped exactly as before.

    Args:
        country: Country name as read from COUNTRIES_JSON_PATH.

    Returns:
        The sanitized, lower-case id fragment.
    """
    cleaned = (
        country.lower()
        .replace(" ", "_")
        .replace("-", "_")
        .replace("(", "")
        .replace(")", "")
    )
    return _NON_WORD_CHARS.sub("_", cleaned)


def get_all_countries():
    """Return every country name listed in COUNTRIES_JSON_PATH, lower-cased.

    The file must hold a list of region objects, each with a "countries"
    list whose entries have a "country" name.

    Returns:
        list[str]: Lower-cased country names in file order.

    Raises:
        OSError / json.JSONDecodeError / KeyError: if the file is missing,
        malformed or does not follow the layout above.
    """
    with open(COUNTRIES_JSON_PATH, "r", encoding="utf-8") as file:
        regions = json.load(file)
    countries = [
        entry["country"].lower()
        for region in regions
        for entry in region["countries"]
    ]
    # Called on every DAG parse, so keep this out of the default log level.
    logging.debug("Loaded %d countries: %s", len(countries), countries)
    return countries


# Loaded at import time because tasks are generated per country while the DAG
# file is parsed.
all_countries = get_all_countries()
# ==============================
# 🔹 Step 1: get sectors and industries
# ==============================
datasets = ["sectors", "industries"]

# Maximum number of component URLs scraped by one mapped batch task.
COMPONENTS_BATCH_SIZE = 5
# Cassandra table that receives scraped components for both datasets.
COMPONENTS_TABLE = "tradingview_icb_components"


def _check_dataset(dataset):
    """Raise ValueError unless *dataset* is one of ``datasets``.

    Besides guarding the dataset branches, this keeps arbitrary text out of
    the ``tradingview_<dataset>`` table names interpolated into CQL.
    """
    if dataset not in datasets:
        raise ValueError(f"Unknown dataset {dataset!r}; expected one of {datasets}")


def _connect_cassandra():
    """Return a new CassandraClient that is connected and has its keyspace created.

    The caller owns the client and must close it (use try/finally).
    """
    cassandra_client = CassandraClient()
    cassandra_client.connect()
    cassandra_client.create_keyspace()
    return cassandra_client


def _read_dataset_rows(country, dataset):
    """Return the rows of ``tradingview_<dataset>`` whose country column equals *country*."""
    _check_dataset(dataset)
    query = f"SELECT * FROM tradingview_{dataset} WHERE country = %s ALLOW FILTERING"
    cassandra_client = _connect_cassandra()
    try:
        return cassandra_client.query_to_dataframe(query, (str(country),))
    finally:
        cassandra_client.close()


def _insert_components(components_df):
    """Insert *components_df* into COMPONENTS_TABLE.

    Returns:
        bool: True if rows were inserted; False (without opening a
        connection) when the frame is empty.
    """
    if components_df.empty:
        return False
    cassandra_client = _connect_cassandra()
    try:
        cassandra_client.insert_df_to_cassandra(components_df, COMPONENTS_TABLE)
    finally:
        cassandra_client.close()
    return True


def _concat_frames(frames):
    """Concatenate *frames* (skipping None), or return an empty DataFrame if none remain.

    pd.concat itself raises ValueError for an empty or all-None input.
    """
    frames = [frame for frame in frames if frame is not None]
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)


def _fetch_components_for_links(links, max_sleep_seconds):
    """Scrape the components behind each link on a 5-worker thread pool.

    Before each request the worker sleeps a random whole number of seconds
    in [1, max_sleep_seconds] (presumably to stay under TradingView rate
    limits).

    Args:
        links: Component URLs; None is treated as empty.
        max_sleep_seconds: Upper bound of the random pre-request delay.

    Returns:
        list[pd.DataFrame]: One frame per link, in the order of *links*, each
        with a ``component_url`` column naming its source. Falsy links are
        not requested and yield an empty frame.
    """

    def fetch_one(link):
        if not link:
            return pd.DataFrame()
        logging.info("Fetching components from %s", link)
        time.sleep(random.randint(1, max_sleep_seconds))
        return get_tradingview_sectors_industries_components(link).assign(
            component_url=link
        )

    with ThreadPoolExecutor(max_workers=5) as executor:
        return list(executor.map(fetch_one, links or []))


def fetch_tradingview_sector_and_industry_by_country(country, dataset):
    """Scrape one country's TradingView sectors or industries into Cassandra.

    Args:
        country: Country name as listed in all_countries.
        dataset: "sectors" or "industries"; selects the scraper and the
            target table ``tradingview_<dataset>``.

    Returns:
        str: A status message.

    Raises:
        ValueError: if *dataset* is unsupported or the scraper returned no data.
    """
    _check_dataset(dataset)
    # Spaces become hyphens before the name is handed to the scraper.
    country_source = country.lower().replace(" ", "-")
    if dataset == "sectors":
        df = get_tradingview_sectors_by_country(country=country_source)
    else:
        df = get_tradingview_industries_by_country(country=country_source)

    # Validate before view_data so a missing result fails with a clear message.
    if df is None or df.empty:
        raise ValueError("❌ ERROR: df is None or empty! Check the data source.")
    view_data(df)

    # Connect only after scraping so no connection is held (or leaked) meanwhile.
    cassandra_client = _connect_cassandra()
    try:
        cassandra_client.insert_df_to_cassandra(df, f"tradingview_{dataset}")
    finally:
        cassandra_client.close()
    return f"Fetch successfully {dataset} of {country}!!!!"


# ==============================
# 🔹 Step 2: get components
# ==============================
def fetch_tradingview_sector_and_industry_components_by_country(country, dataset):
    """Scrape and store the components of all of *country*'s sectors or industries at once.

    Reads the component URLs from ``tradingview_<dataset>``, scrapes each one
    and inserts the combined result into COMPONENTS_TABLE. The DAG in this
    file does not call this function; it splits the same work into batches
    with get_component_url_batches and fetch_components_batch.

    Returns:
        str: A status message; a "No data" message when the country has no rows.
    """
    rows_df = _read_dataset_rows(country, dataset)
    if rows_df.empty:
        logging.warning(
            "⚠️ No data found for country: %s in %s. Skipping...", country, dataset
        )
        return f"No data for {country}, skipping..."
    view_data(rows_df)

    frames = _fetch_components_for_links(
        rows_df["component_url"].tolist(), max_sleep_seconds=10
    )
    components_df = _concat_frames(frames)
    view_data(components_df)
    _insert_components(components_df)
    return f"fetch successfully {dataset} components of {country}!!!!"


def fetch_components_by_list(country, dataset, component_urls):
    """Fetch components for a subset of component_urls and store them.

    Args:
        country: Only used in the returned status message.
        dataset: Only used in the returned status message.
        component_urls: URLs to scrape; None is treated as an empty list.

    Returns:
        str: Status message with the number of links processed.
    """
    component_urls = list(component_urls or [])
    frames = _fetch_components_for_links(component_urls, max_sleep_seconds=10)
    _insert_components(_concat_frames(frames))
    return (
        f"Fetched components for {len(component_urls)} links in {dataset} of {country}"
    )


# ==============================
# 🔹 Step 3: split component URLs into batches
# ==============================
def get_component_url_batches(country, dataset, batch_size=COMPONENTS_BATCH_SIZE):
    """Build the fetch_components_batch arguments for *country*'s stored component URLs.

    Args:
        country: Country name as listed in all_countries.
        dataset: "sectors" or "industries".
        batch_size: Maximum number of URLs per batch (>= 1).

    Returns:
        list[list]: One ``[links, country, dataset, batch_index]`` entry per
        batch. An empty list when the country has no rows; Airflow then marks
        the mapped batch task as skipped.

    Raises:
        ValueError: if *batch_size* < 1 or *dataset* is unsupported.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    rows_df = _read_dataset_rows(country, dataset)
    if rows_df.empty:
        logging.warning(
            "⚠️ No data found for country: %s in %s. Skipping...", country, dataset
        )
        return []
    links = [
        link
        for link in rows_df["component_url"].tolist()
        if isinstance(link, str) and link
    ]
    return [
        [links[start : start + batch_size], country, dataset, batch_index]
        for batch_index, start in enumerate(range(0, len(links), batch_size))
    ]


# ==============================
# 🔹 Step 4: concat and save
# ==============================
def fetch_components_batch(links, country, dataset, batch_index):
    """Scrape one batch of component URLs and store the result.

    Args:
        links: Component URLs belonging to this batch.
        country: Identifies the batch in logs and messages.
        dataset: Identifies the batch in logs and messages.
        batch_index: Identifies the batch in logs and messages.

    Returns:
        The inserted DataFrame, or a status string when the batch produced no
        rows. The DAG runs this task with do_xcom_push=False so the DataFrame
        is never serialized into XCom.

    Raises:
        Re-raises any scraping or Cassandra error after logging it.
    """
    try:
        frames = _fetch_components_for_links(links, max_sleep_seconds=3)
        if not frames:
            logging.warning(
                "No components fetched for batch %s of %s in %s",
                batch_index,
                dataset,
                country,
            )
            return f"No components for batch {batch_index} of {dataset} in {country}"

        components_df = _concat_frames(frames)
        if components_df.empty:
            logging.warning(
                "Empty components DataFrame for batch %s of %s in %s",
                batch_index,
                dataset,
                country,
            )
            return f"Empty components for batch {batch_index} of {dataset} in {country}"

        _insert_components(components_df)
        return components_df
    except Exception:
        logging.exception(
            "Error fetching components for batch %s of %s in %s",
            batch_index,
            dataset,
            country,
        )
        raise


def concat_and_save(**kwargs):
    """Pull per-country sector/industry frames from XCom, concatenate and print them.

    NOTE(review): this callable is not wired into the DAG, no task in this
    file pushes the XCom keys pulled here, and the
    ``fetch_components_tasks`` group does not exist, so every pull currently
    returns None and the printed frames are empty. Kept for reference.
    """
    ti = kwargs["ti"]
    collected = {
        "sectors": [],
        "sectors_components": [],
        "industries": [],
        "industries_components": [],
    }
    for country in all_countries:
        cleaned = clean_country_name(country)
        # Use the loop's own dataset; the old code read an undefined/leaked global.
        for dataset in datasets:
            collected[dataset].append(
                ti.xcom_pull(
                    task_ids=f"fetch_sectors_and_industries_tasks.fetch_{cleaned}_{dataset}",
                    key=f"{cleaned}_{dataset}",
                )
            )
            collected[f"{dataset}_components"].append(
                ti.xcom_pull(
                    task_ids=f"fetch_components_tasks.fetch_{cleaned}_{dataset}_components",
                    key=f"{cleaned}_{dataset}_components",
                )
            )

    if all(value is None for values in collected.values() for value in values):
        logging.warning("concat_and_save: no XCom data found for any country.")

    for name in ("sectors", "sectors_components", "industries", "industries_components"):
        print(_concat_frames(collected[name]))


def cleanup(context=None):
    """DAG success/failure callback.

    Every task opens and closes its own CassandraClient, so there is no
    shared connection left to close; the callback only logs. *context* is
    the callback context Airflow passes in; it is accepted and ignored.
    """
    logging.info("DAG run finished; Cassandra connections are closed per task.")


# Define DAG
# Dynamic task mapping (PythonOperator.partial().expand()) requires Airflow >= 2.3.
with DAG(
    "trading_view_get_sectors_and_industries",
    default_args=default_args,
    schedule_interval=None,
    catchup=False,
    # `concurrency` is a deprecated alias of max_active_tasks; only the latter is set.
    max_active_tasks=32,
    on_success_callback=cleanup,
    on_failure_callback=cleanup,
) as dag:
    list_countries_task = PythonOperator(
        task_id="get_list_countries",
        python_callable=get_all_countries,
    )

    with TaskGroup(
        "fetch_sectors_and_industries_tasks"
    ) as fetch_sectors_and_industries_group:
        for country in all_countries:
            for dataset in datasets:
                PythonOperator(
                    task_id=f"fetch_{clean_country_name(country)}_{dataset}",
                    python_callable=fetch_tradingview_sector_and_industry_by_country,
                    op_args=[country, dataset],
                )

    batch_size = COMPONENTS_BATCH_SIZE
    country_groups = []
    for country in all_countries:
        cleaned_country = clean_country_name(country)
        with TaskGroup(f"fetch_{cleaned_country}_tasks") as country_group:
            for dataset in datasets:
                # URLs are looked up at run time: they only exist once the
                # sector/industry tasks upstream have stored them.
                fetch_components_urls_task = PythonOperator(
                    task_id=f"fetch_{cleaned_country}_{dataset}_components_urls",
                    python_callable=get_component_url_batches,
                    op_args=[country, dataset, batch_size],
                )
                # One mapped task instance per batch returned by the URL task.
                PythonOperator.partial(
                    task_id=f"fetch_{cleaned_country}_{dataset}_components_batch",
                    python_callable=fetch_components_batch,
                    do_xcom_push=False,
                ).expand(op_args=fetch_components_urls_task.output)
        country_groups.append(country_group)

    list_countries_task >> fetch_sectors_and_industries_group >> country_groups
Editor is loading...
Leave a Comment